Skip to content

Commit

Permalink
use local storage instead of cookies (#189)
Browse files Browse the repository at this point in the history
* use local storage instead of cookies

* fix tests
  • Loading branch information
codekansas authored Jul 26, 2024
1 parent d384cf3 commit 9700a65
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 85 deletions.
2 changes: 1 addition & 1 deletion frontend/src/hooks/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export interface Robot {
}

interface GithubAuthResponse {
api_key_id: string;
api_key: string;
}

interface MeResponse {
Expand Down
13 changes: 13 additions & 0 deletions frontend/src/hooks/auth.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => {
withCredentials: true,
});

if (id !== null) {
// Adds the API key to the request header.
api.interceptors.request.use(
(config) => {
config.headers.Authorization = `Bearer ${id}`;
return config;
},
(error) => {
return Promise.reject(error);
},
);
}

const login = useCallback((apiKeyId: string) => {
(async () => {
setLocalStorageAuth(apiKeyId);
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/pages/Login.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ const Login = () => {
if (code) {
setUseSpinner(true);
const res = await auth_api.loginGithub(code as string);
auth.login(res.api_key_id);
auth.login(res.api_key);
}
})();
}, []);
Expand Down
22 changes: 14 additions & 8 deletions store/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
GlobalSecondaryIndex = tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]


class ItemNotFoundError(ValueError): ...


class InternalError(RuntimeError): ...


class BaseCrud(AsyncContextManager["BaseCrud"]):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -73,7 +79,7 @@ async def _add_item(self, item: RobolistBaseModel, unique_fields: list[str] | No
table = await self.db.Table(TABLE_NAME)
item_data = item.model_dump()
if "type" in item_data:
raise ValueError("Cannot add item with 'type' attribute")
raise InternalError("Cannot add item with 'type' attribute")
item_data["type"] = item.__class__.__name__
condition = "attribute_not_exists(id)"
if unique_fields:
Expand Down Expand Up @@ -173,7 +179,7 @@ async def _count_items(self, item_class: type[T]) -> int:

def _validate_item(self, data: dict[str, Any], item_class: type[T]) -> T:
if (item_type := data.pop("type")) != item_class.__name__:
raise ValueError(f"Item type {str(item_type)} is not a {item_class.__name__}")
raise InternalError(f"Item type {str(item_type)} is not a {item_class.__name__}")
return item_class.model_validate(data)

@overload
Expand All @@ -187,7 +193,7 @@ async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: b
item_dict = await table.get_item(Key={"id": item_id})
if "Item" not in item_dict:
if throw_if_missing:
raise ValueError("Item not found")
raise ItemNotFoundError
return None
item_data = item_dict["Item"]
return self._validate_item(item_data, item_class)
Expand Down Expand Up @@ -256,7 +262,7 @@ async def _get_unique_item_from_secondary_index(
throw_if_missing: bool = False,
) -> T | None:
if secondary_index_name not in item_class.model_fields:
raise ValueError(f"Field '{secondary_index_name}' not in model {item_class.__name__}")
raise InternalError(f"Field '{secondary_index_name}' not in model {item_class.__name__}")
items = await self._get_items_from_secondary_index(
secondary_index,
secondary_index_name,
Expand All @@ -265,17 +271,17 @@ async def _get_unique_item_from_secondary_index(
)
if len(items) == 0:
if throw_if_missing:
raise ValueError(f"No items found with {secondary_index_name} {secondary_index_value}")
raise InternalError(f"No items found with {secondary_index_name} {secondary_index_value}")
return None
if len(items) > 1:
raise ValueError(f"Multiple items found with {secondary_index_name} {secondary_index_value}")
raise InternalError(f"Multiple items found with {secondary_index_name} {secondary_index_value}")
return items[0]

async def _update_item(self, item_id: str, item_class: type[T], new_values: dict[str, Any]) -> None: # noqa: ANN401
# Validates the new values.
for field_name, field_value in new_values.items():
for field_name in new_values.keys():
if item_class.model_fields.get(field_name) is None:
raise ValueError(f"Field {field_name} not in model {item_class.__name__}")
raise InternalError(f"Field {field_name} not in model {item_class.__name__}")

# Updates the table.
table = await self.db.Table(TABLE_NAME)
Expand Down
27 changes: 26 additions & 1 deletion store/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from store.app.crud.base import InternalError, ItemNotFoundError
from store.app.db import create_tables
from store.app.routers.image import image_router
from store.app.routers.part import parts_router
from store.app.routers.robot import robots_router
from store.app.routers.users import users_router
from store.app.routers.users import NotAuthenticatedError, users_router
from store.settings import settings

LOCALHOST_URLS = [
Expand Down Expand Up @@ -60,6 +61,30 @@ async def value_error_exception_handler(request: Request, exc: ValueError) -> JS
)


@app.exception_handler(ItemNotFoundError)
async def item_not_found_exception_handler(request: Request, exc: ItemNotFoundError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"message": "Item not found.", "detail": str(exc)},
)


@app.exception_handler(InternalError)
async def internal_error_exception_handler(request: Request, exc: InternalError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": "Internal error.", "detail": str(exc)},
)


@app.exception_handler(NotAuthenticatedError)
async def not_authenticated_exception_handler(request: Request, exc: NotAuthenticatedError) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"message": "Not authenticated.", "detail": str(exc)},
)


@app.get("/")
async def read_root() -> bool:
return True
Expand Down
6 changes: 2 additions & 4 deletions store/app/routers/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def github_email_req(headers: dict[str, str]) -> HttpxResponse:


class GithubAuthResponse(BaseModel):
api_key_id: str
api_key: str


@github_auth_router.get("/code/{code}", response_model=GithubAuthResponse)
Expand Down Expand Up @@ -94,6 +94,4 @@ async def github_code(
permissions="full", # OAuth tokens have full permissions.
)

response.set_cookie(key="session_token", value=api_key.id, httponly=True, samesite="lax")

return GithubAuthResponse(api_key_id=api_key.id)
return GithubAuthResponse(api_key=api_key.id)
82 changes: 42 additions & 40 deletions store/app/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.security.utils import get_authorization_scheme_param
from pydantic.main import BaseModel as PydanticBaseModel

from store.app.crud.base import ItemNotFoundError
from store.app.db import Crud
from store.app.model import User, UserPermission
from store.app.routers.auth.github import github_auth_router
Expand All @@ -20,73 +21,75 @@
TOKEN_TYPE = "Bearer"


class NotAuthenticatedError(Exception): ...


class BaseModel(PydanticBaseModel):
class Config:
arbitrary_types_allowed = True


def set_token_cookie(response: Response, token: str, key: str) -> None:
response.set_cookie(
key=key,
value=token,
httponly=True,
secure=False,
samesite="lax",
)


async def get_request_api_key_id(request: Request) -> str:
api_key_id = request.cookies.get("session_token")
if not api_key_id:
authorization = request.headers.get("Authorization") or request.headers.get("authorization")
if authorization:
scheme, credentials = get_authorization_scheme_param(authorization)
if not (scheme and credentials):
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail="Authorization header is invalid",
)
if scheme.lower() != TOKEN_TYPE.lower():
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail="Authorization scheme is invalid",
)
return credentials
authorization = request.headers.get("Authorization") or request.headers.get("authorization")
if not authorization:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
return api_key_id
scheme, credentials = get_authorization_scheme_param(authorization)
if not (scheme and credentials):
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail="Authorization header is invalid",
)
if scheme.lower() != TOKEN_TYPE.lower():
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail="Authorization scheme is invalid",
)
return credentials


async def get_session_user_with_read_permission(
crud: Annotated[Crud, Depends(Crud.get)],
api_key_id: Annotated[str, Depends(get_request_api_key_id)],
) -> User:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "read" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
return await crud.get_user(api_key.user_id, throw_if_missing=True)
try:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "read" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
try:
return await crud.get_user(api_key.user_id, throw_if_missing=True)
except ItemNotFoundError:
raise NotAuthenticatedError("Not authenticated")
except ItemNotFoundError:
raise NotAuthenticatedError("Not authenticated")


async def get_session_user_with_write_permission(
crud: Annotated[Crud, Depends(Crud.get)],
api_key_id: Annotated[str, Depends(get_request_api_key_id)],
) -> User:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "write" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
return await crud.get_user(api_key.user_id, throw_if_missing=True)
try:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "write" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
return await crud.get_user(api_key.user_id, throw_if_missing=True)
except ItemNotFoundError:
raise NotAuthenticatedError("Not authenticated")


async def get_session_user_with_admin_permission(
crud: Annotated[Crud, Depends(Crud.get)],
api_key_id: Annotated[str, Depends(get_request_api_key_id)],
) -> User:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "admin" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
return await crud.get_user(api_key.user_id, throw_if_missing=True)
try:
api_key = await crud.get_api_key(api_key_id)
if api_key.permissions is None or "admin" not in api_key.permissions:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied")
return await crud.get_user(api_key.user_id, throw_if_missing=True)
except ItemNotFoundError:
raise NotAuthenticatedError("Not authenticated")


def validate_email(email: str) -> str:
Expand Down Expand Up @@ -140,7 +143,6 @@ async def logout_user_endpoint(
response: Response,
) -> bool:
await crud.delete_api_key(token)
response.delete_cookie("session_token")
return True


Expand Down
1 change: 0 additions & 1 deletion tests/test_robots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ async def test_robots(app_client: AsyncClient) -> None:
# Register.
response = await app_client.get("/users/github/code/doesnt-matter")
assert response.status_code == 200, response.json()
assert "session_token" in response.cookies

# Create a part.
response = await app_client.post(
Expand Down
50 changes: 21 additions & 29 deletions tests/test_users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Runs tests on the user APIs."""

from fastapi import status
from httpx import AsyncClient

from store.app.db import create_tables
Expand All @@ -10,55 +11,46 @@ async def test_user_auth_functions(app_client: AsyncClient) -> None:

# Checks that without the session token we get a 401 response.
response = await app_client.get("/users/me")
assert response.status_code == 401, response.json()
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert response.json()["detail"] == "Not authenticated"

# Checks that we can't log the user out without the session token.
response = await app_client.delete("/users/logout")
assert response.status_code == 401, response.json()
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()

# Because of the way we patched GitHub functions for mocking, it doesn't matter what token we pass in.
response = await app_client.get("/users/github/code/doesnt-matter")
assert response.status_code == 200, response.json()
assert "session_token" in response.cookies
token = response.cookies["session_token"]
assert token == response.json()["api_key_id"]
assert response.status_code == status.HTTP_200_OK, response.json()
token = response.json()["api_key"]
auth_headers = {"Authorization": f"Bearer {token}"}

# Checks that with the session token we get a 200 response.
response = await app_client.get("/users/me")
assert response.status_code == 200, response.json()
user_id = response.json()["user_id"]

# Use the Authorization header instead of the cookie.
response = await app_client.get(
"/users/me",
cookies={"session_token": ""},
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200, response.json()
assert response.json()["user_id"] == user_id
response = await app_client.get("/users/me", headers=auth_headers)
assert response.status_code == status.HTTP_200_OK, response.json()

# Log the user out, which deletes the session token.
response = await app_client.delete("/users/logout")
assert response.status_code == 200, response.json()
response = await app_client.delete("/users/logout", headers=auth_headers)
assert response.status_code == status.HTTP_200_OK, response.json()
assert response.json() is True

# Checks that we can no longer use that session token to get the user's info.
response = await app_client.get("/users/me")
assert response.status_code == 401, response.json()
response = await app_client.get("/users/me", headers=auth_headers)
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert response.json()["detail"] == "Not authenticated"

# Log the user back in, getting new session token.
response = await app_client.get("/users/github/code/doesnt-matter")
assert response.status_code == 200, response.json()
assert "session_token" in response.cookies
assert response.status_code == status.HTTP_200_OK, response.json()
assert response.json()["api_key"] != token
token = response.json()["api_key"]
auth_headers = {"Authorization": f"Bearer {token}"}

# Delete the user using the new session token.
response = await app_client.delete("/users/me")
assert response.status_code == 200, response.json()
response = await app_client.delete("/users/me", headers=auth_headers)
assert response.status_code == status.HTTP_200_OK, response.json()
assert response.json() is True

# Tries deleting the user again, which should fail.
response = await app_client.delete("/users/me")
assert response.status_code == 400, response.json()
assert response.json()["detail"] == "Item not found"
response = await app_client.delete("/users/me", headers=auth_headers)
assert response.status_code == status.HTTP_401_UNAUTHORIZED, response.json()
assert response.json()["detail"] == "Not authenticated"

0 comments on commit 9700a65

Please sign in to comment.