Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Robot table #19

Closed
wants to merge 11 commits into from
Closed
29 changes: 26 additions & 3 deletions store/app/api/crud/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Defines the base CRUD interface."""

import itertools
from typing import Any, AsyncContextManager, Literal, Self

import aioboto3
Expand Down Expand Up @@ -29,16 +30,38 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: #
if self.__db is not None:
await self.__db.__aexit__(exc_type, exc_val, exc_tb)

"""Creates a table in the Dynamo database.

Args:
name: Name of the table.
keys: Primary and secondary keys. Do not include non-key attributes.
gsis: Making an attribute a GSI is required in oredr to query against it.
Note HASH on a GSI does not actually enforce uniqueness.
Instead, the difference is: you cannot query RANGE fields alone but you may query HASH fields
deletion_protection: Whether the table is protected from being deleted.
"""
async def _create_dynamodb_table(
self,
name: str,
columns: list[tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]],
keys: list[tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]],
gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] = [],
deletion_protection: bool = False,
) -> None:
table = await self.db.create_table(
AttributeDefinitions=[{"AttributeName": n, "AttributeType": t} for n, t, _ in columns],
AttributeDefinitions=[
{"AttributeName": n, "AttributeType": t}
for n, t in itertools.chain(((n, t) for (n, t, _) in keys), ((n, t) for _, n, t, _ in gsis))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

beautiful :)

],
TableName=name,
KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in columns],
KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys],
GlobalSecondaryIndexes=[
{
"IndexName": i,
"KeySchema": [{"AttributeName": n, "KeyType": t}],
"Projection": {"ProjectionType": "ALL"},
}
for i, n, _, t in gsis
],
DeletionProtectionEnabled=deletion_protection,
BillingMode="PAY_PER_REQUEST",
)
Expand Down
19 changes: 11 additions & 8 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Defines CRUD interface for user API."""

import asyncio
import uuid
import warnings

from boto3.dynamodb.conditions import Key

from store.app.api.crud.base import BaseCrud
from store.app.api.model import Token, User

Expand All @@ -14,15 +17,15 @@ async def add_user(self, user: User) -> None:

async def get_user(self, email: str) -> User | None:
table = await self.db.Table("Users")
user_dict = await table.get_item(Key={"email": email})
if "Item" not in user_dict:
user_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=Key("email").eq(email))
if len(user_dict["Items"]) == 0:
return None
user = User.model_validate(user_dict["Item"])
user = User.model_validate(user_dict["Items"][0])
return user

async def delete_user(self, user: User) -> None:
table = await self.db.Table("Users")
await table.delete_item(Key={"email": user.email})
await table.delete_item(Key={"id": user.id})

async def list_users(self) -> list[User]:
warnings.warn("`list_users` probably shouldn't be called in production", ResourceWarning)
Expand All @@ -40,16 +43,16 @@ async def add_token(self, token: Token) -> None:

async def get_token(self, email: str) -> Token | None:
table = await self.db.Table("Tokens")
token_dict = await table.get_item(Key={"email": email})
if "Item" not in token_dict:
token_dict = await table.query(IndexName="emailIndex", KeyConditionExpression=Key("email").eq(email))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not get_item here? this should be a unique element right, not a list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately DynamoDB does not allow you to get_item on anything besides a primary key. It has no mechanism of enforcing uniqueness on secondary indices, which is why we have to manually check that emails are unique whenever we attempt to perform a user insertion.

Furthermore, even after we manually enforce uniqueness, DynamoDB still has no interest in just returning the first entry with the correct email it finds. There is no method to perform a singular query on a secondary index.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

never mind we will only be using email for tokens

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually tokens and the entire auth structure require significant rearch. think it is worth considering using traditional cookie based auth since it is much simpler and easier to implement. (notably i believe this implementation is wrong, though how easy it is to fix is not known to me)

if len(token_dict["Items"]) == 0:
return None
token = Token.model_validate(token_dict["Item"])
token = Token.model_validate(token_dict["Items"][0])
return token


async def test_adhoc() -> None:
async with UserCrud() as crud:
await crud.add_user(User(email="ben@kscale.dev"))
await crud.add_user(User(id=str(uuid.uuid4()), email="ben@kscale.dev"))
# print(await crud.get_user("ben"))
# print(await crud.get_user_count())
# await crud.get_token("ben")
Expand Down
23 changes: 15 additions & 8 deletions store/app/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,26 @@ async def create_tables(crud: Crud | None = None) -> None:
await asyncio.gather(
crud._create_dynamodb_table(
name="Users",
columns=[
("email", "S", "HASH"),
# ("banned", "B", "RANGE"),
# ("deleted", "B", "RANGE"),
keys=[
("id", "S", "HASH"),
],
gsis=[
("emailIndex", "email", "S", "HASH"),
],
),
crud._create_dynamodb_table(
name="Tokens",
columns=[
("email", "S", "HASH"),
# ("issued", "N", "RANGE"),
# ("disabled", "B", "RANGE"),
keys=[
("id", "S", "HASH"),
],
gsis=[("emailIndex", "email", "S", "HASH")],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
gsis=[("emailIndex", "email", "S", "HASH")],
gsis=[
("emailIndex", "email", "S", "HASH"),
],

reason is that if we add more indexes it keeps the file blame easier to parse

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially did that, make format is what made it all in one line. I concur, I would have written it in multiple lines, but probably best to obey the linter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you put a comma inside the list, black will format it this way

),
crud._create_dynamodb_table(
name="Robots",
keys=[
("id", "S", "HASH"),
],
gsis=[("ownerIndex", "owner", "S", "HASH")],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto here

),
)

Expand Down
4 changes: 4 additions & 0 deletions store/app/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@

class User(BaseModel):
email: str
id: str
banned: bool = field(default=False)
deleted: bool = field(default=False)


class Token(BaseModel):
# Email of the user the token belongs to
email: str
# ID of the token itself, not the user it belongs to.
id: str
issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp()))
disabled: bool = field(default=False)
8 changes: 4 additions & 4 deletions store/app/api/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import datetime
import logging
import uuid
from email.utils import parseaddr as parse_email_address
from typing import Annotated

Expand Down Expand Up @@ -33,8 +34,7 @@ def set_token_cookie(response: Response, token: str, key: str) -> None:
value=token,
httponly=True,
secure=False,
# samesite="strict",
samesite="none",
samesite="lax",
)


Expand Down Expand Up @@ -99,15 +99,15 @@ class UserLoginResponse(BaseModel):
async def add_to_waitlist(email: str, crud: Crud) -> None:
await asyncio.gather(
send_waitlist_email(email),
crud.add_user(User(email=email, banned=True)),
crud.add_user(User(id=str(uuid.uuid4()), email=email, banned=True)),
)


async def create_or_get(email: str, crud: Crud) -> User:
# Gets or creates the user object.
user_obj = await crud.get_user(email)
if user_obj is None:
await crud.add_user(User(email=email))
await crud.add_user(User(id=str(uuid.uuid4()), email=email))
if (user_obj := await crud.get_user(email)) is None:
raise RuntimeError("Failed to add user to the database")

Expand Down
7 changes: 4 additions & 3 deletions store/app/api/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import logging
import uuid

import jwt
from fastapi import HTTPException, status
Expand Down Expand Up @@ -72,9 +73,9 @@ async def create_refresh_token(email: str, crud: Crud) -> str:
Returns:
The encoded JWT.
"""
token = Token(email=email)
token = Token(id=str(uuid.uuid4()), email=email)
await crud.add_token(token)
return create_token({"eml": email})
return create_token({"email": email})


def load_refresh_token(payload: str) -> str:
Expand All @@ -87,4 +88,4 @@ def load_refresh_token(payload: str) -> str:
The decoded refresh token data.
"""
data = load_token(payload)
return data["eml"]
return data["email"]
Loading