Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to Redis and add lifetime for API keys #42

Merged
merged 21 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions store/app/api/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import aioboto3
from botocore.exceptions import ClientError
from redis import Redis
codekansas marked this conversation as resolved.
Show resolved Hide resolved
from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource

from store.settings import settings

logger = logging.getLogger(__name__)


Expand All @@ -16,6 +19,7 @@ def __init__(self) -> None:
super().__init__()

self.__db: DynamoDBServiceResource | None = None
self.__kv: Redis | None = None

@property
def db(self) -> DynamoDBServiceResource:
Expand All @@ -28,6 +32,10 @@ async def __aenter__(self) -> Self:
db = session.resource("dynamodb")
db = await db.__aenter__()
self.__db = db

self.kv = Redis(
host=settings.redis.host, password=settings.redis.password, port=settings.redis.port, db=settings.redis.db
codekansas marked this conversation as resolved.
Show resolved Hide resolved
)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401
Expand Down
27 changes: 10 additions & 17 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,11 @@ async def get_user_from_email(self, email: str) -> User | None:
return user

async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None:
table = await self.db.Table("ApiKeys")
api_key_hash = hash_api_key(api_key)
row = await table.get_item(Key={"api_key_hash": api_key_hash})
if "Item" not in row:
user_id = self.kv.get(api_key_hash)
if user_id is None:
return None
user_id = cast(str, row["Item"]["user_id"])
return uuid.UUID(user_id)
return uuid.UUID(user_id.decode('utf-8'))

async def delete_user(self, user: User) -> None:
table = await self.db.Table("Users")
Expand All @@ -59,21 +57,16 @@ async def get_user_count(self) -> int:
table = await self.db.Table("Users")
return await table.item_count

async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None:
row = ApiKey.from_api_key(api_key, user_id)
table = await self.db.Table("ApiKeys")
await table.put_item(Item=row.model_dump())
async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> None:
row = ApiKey.from_api_key(api_key, user_id, lifetime)
self.kv.setex(row.api_key_hash, row.lifetime, row.user_id)

async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool:
table = await self.db.Table("ApiKeys")
row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)})
if "Item" not in row:
return False
return row["Item"]["user_id"] == str(user_id)
def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool:
row = self.kv.get(hash_api_key(api_key))
return row is not None and row == user_id

async def delete_api_key(self, api_key: uuid.UUID) -> None:
table = await self.db.Table("ApiKeys")
await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)})
self.kv.delete(hash_api_key(api_key))


async def test_adhoc() -> None:
Expand Down
9 changes: 0 additions & 9 deletions store/app/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ async def create_tables(crud: Crud | None = None) -> None:
("emailIndex", "email", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="ApiKeys",
keys=[
("api_key_hash", "S", "HASH"),
],
gsis=[
("userIdIndex", "user_id", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="Robots",
keys=[
Expand Down
5 changes: 3 additions & 2 deletions store/app/api/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ async def send_email(subject: str, body: str, to: str) -> None:
@dataclass
class OneTimePassPayload:
email: str
lifetime: str

def encode(self) -> str:
expire_minutes = settings.crypto.expire_otp_minutes
expire_after = datetime.timedelta(minutes=expire_minutes)
return encode_jwt({"email": self.email}, expire_after=expire_after)
return encode_jwt({"email": self.email, "lifetime": self.lifetime}, expire_after=expire_after)

@classmethod
def decode(cls, payload: str) -> "OneTimePassPayload":
data = decode_jwt(payload)
return cls(email=data["email"])
return cls(email=data["email"], lifetime=data["lifetime"])


async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None:
Expand Down
9 changes: 4 additions & 5 deletions store/app/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
expects (for example, converting a UUID into a string).
"""

import datetime
import uuid
from dataclasses import field
from decimal import Decimal

from pydantic import BaseModel
Expand All @@ -27,15 +25,16 @@ def to_uuid(self) -> uuid.UUID:
return uuid.UUID(self.user_id)


# Stored in Redis rather than DynamoDB
codekansas marked this conversation as resolved.
Show resolved Hide resolved
class ApiKey(BaseModel):
api_key_hash: str # Primary key
user_id: str
issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp()))
lifetime: int

@classmethod
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID) -> "ApiKey":
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> "ApiKey":
api_key_hash = hash_api_key(api_key)
return cls(api_key_hash=api_key_hash, user_id=str(user_id))
return cls(api_key_hash=api_key_hash, user_id=str(user_id), lifetime=lifetime)


class PurchaseLink(BaseModel):
Expand Down
10 changes: 6 additions & 4 deletions store/app/api/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def get_api_key(request: Request) -> ApiKeyData:
class UserSignup(BaseModel):
email: str
login_url: str
lifetime: int


def validate_email(email: str) -> str:
Expand All @@ -76,7 +77,7 @@ async def login_user_endpoint(data: UserSignup) -> bool:
True if the email was sent successfully.
"""
email = validate_email(data.email)
payload = OneTimePassPayload(email)
payload = OneTimePassPayload(email, lifetime=data.lifetime)
await send_otp_email(payload, data.login_url)
return True

Expand All @@ -89,7 +90,7 @@ class UserLoginResponse(BaseModel):
api_key: str


async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
async def get_login_response(email: str, lifetime: int, crud: Crud) -> UserLoginResponse:
"""Takes the user email and returns an API key.

This function gets a user API key for an email which has been validated,
Expand All @@ -98,6 +99,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
Args:
email: The validated email of the user.
crud: The database CRUD object.
lifetime: The lifetime (in seconds) of the API key to be returned.

Returns:
The API key for the user.
Expand All @@ -112,7 +114,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
# Issue a new API key for the user.
user_id: uuid.UUID = user_obj.to_uuid()
api_key: uuid.UUID = get_new_api_key(user_id)
await crud.add_api_key(api_key, user_id)
await crud.add_api_key(api_key, user_id, lifetime)

return UserLoginResponse(api_key=str(api_key))

Expand All @@ -132,7 +134,7 @@ async def otp_endpoint(
The API key if the one-time password is valid.
"""
payload = OneTimePassPayload.decode(data.payload)
return await get_login_response(payload.email, crud)
return await get_login_response(payload.email, payload.lifetime, crud)


async def get_google_user_info(token: str) -> dict:
Expand Down
5 changes: 5 additions & 0 deletions store/settings/configs/local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ email:
email: ${oc.env:ROBOLIST_SMTP_EMAIL}
password: ${oc.env:ROBOLIST_SMTP_PASSWORD}
name: ${oc.env:ROBOLIST_SMTP_NAME}
redis:
host: ${oc.env:ROBOLIST_REDIS_HOST,localhost}
port: ${oc.env:ROBOLIST_REDIS_PORT,6379}
password: ${oc.env:ROBOLIST_REDIS_PASSWORD,}
db: ${oc.env:ROBOLIST_REDIS_DB,0}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after thinking about this, let's remove these from the .yaml files and instead set the defaults in the environment.py file

e.g.

from omegaconf import II

@dataclass
class RedisSettings:
    host: str = field(default=II("ROBOLIST_REDIS_HOST"))
    password: str = field(default=II("ROBOLIST_REDIS_PASSWORD"))
    port: int = field(default=6379)
    db: int = field(default=0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we do the same thing for email?

Copy link
Member

@codekansas codekansas Jun 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes - actually, just did this: #43

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think putting it here makes life less annoying perhaps (local.yaml is checked into git, and this prevents people from needing to set up env vars for local dev that will realistically always be the same)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm since its already merged

9 changes: 9 additions & 0 deletions store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
from omegaconf import MISSING


@dataclass
class RedisSettings:
host: str = field(default=MISSING)
password: str = field(default=MISSING)
port: int = field(default=6379)
db: int = field(default=0)


@dataclass
class CryptoSettings:
expire_token_minutes: int = field(default=10)
Expand Down Expand Up @@ -37,6 +45,7 @@ class SiteSettings:

@dataclass
class EnvironmentSettings:
redis: RedisSettings = field(default_factory=RedisSettings)
user: UserSettings = field(default_factory=UserSettings)
crypto: CryptoSettings = field(default_factory=CryptoSettings)
email: EmailSettings = field(default_factory=EmailSettings)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType)
login_url = "/"

# Sends the one-time password to the test email.
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url})
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url, "lifetime": 3600})
codekansas marked this conversation as resolved.
Show resolved Hide resolved
assert response.status_code == 200, response.json()
assert mock_send_email.call_count == 1

# Uses the one-time pass to get an API key. We need to make a new OTP
# manually because we can't send emails in unit tests.
otp = OneTimePassPayload(email=test_email)
otp = OneTimePassPayload(email=test_email, lifetime=3600)
response = app_client.post("/api/users/otp", json={"payload": otp.encode()})
assert response.status_code == 200, response.json()
response_data = response.json()
Expand Down
Loading