Skip to content

Commit

Permalink
Convert to Redis and add lifetime for API keys (#42)
Browse files Browse the repository at this point in the history
* Convert to Redis and add lifetime for API keys

* env vars

* lints

* fix some fxns to use redis

* various fixes

* env vars

* lints

* interpolation keys for redis

* fix some google stuff and pass lints

* add redis env vars to tests/conftest.py

* fix env vars for redis

* add fake redis to test

* fix stupid var assignment

* delete useless comment

* change sync redis to async redis and magically make test work

* add env variables to redis fixture

* document magic number 604800
  • Loading branch information
chennisden authored Jun 4, 2024
1 parent 5af125d commit 89be7b6
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 40 deletions.
11 changes: 11 additions & 0 deletions store/app/api/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import aioboto3
from botocore.exceptions import ClientError
from redis.asyncio import Redis
from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource

from store.settings import settings

logger = logging.getLogger(__name__)


Expand All @@ -16,6 +19,7 @@ def __init__(self) -> None:
super().__init__()

self.__db: DynamoDBServiceResource | None = None
self.__kv: Redis | None = None

@property
def db(self) -> DynamoDBServiceResource:
Expand All @@ -28,6 +32,13 @@ async def __aenter__(self) -> Self:
db = session.resource("dynamodb")
db = await db.__aenter__()
self.__db = db

self.kv = Redis(
host=settings.redis.host,
password=settings.redis.password,
port=settings.redis.port,
db=settings.redis.db,
)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401
Expand Down
26 changes: 9 additions & 17 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import uuid
import warnings
from typing import cast

from boto3.dynamodb.conditions import Key as KeyCondition

Expand Down Expand Up @@ -37,13 +36,11 @@ async def get_user_from_email(self, email: str) -> User | None:
return user

async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None:
table = await self.db.Table("ApiKeys")
api_key_hash = hash_api_key(api_key)
row = await table.get_item(Key={"api_key_hash": api_key_hash})
if "Item" not in row:
user_id = await self.kv.get(api_key_hash)
if user_id is None:
return None
user_id = cast(str, row["Item"]["user_id"])
return uuid.UUID(user_id)
return uuid.UUID(user_id.decode("utf-8"))

async def delete_user(self, user: User) -> None:
table = await self.db.Table("Users")
Expand All @@ -59,21 +56,16 @@ async def get_user_count(self) -> int:
table = await self.db.Table("Users")
return await table.item_count

async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None:
row = ApiKey.from_api_key(api_key, user_id)
table = await self.db.Table("ApiKeys")
await table.put_item(Item=row.model_dump())
async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> None:
row = ApiKey.from_api_key(api_key, user_id, lifetime)
await self.kv.setex(row.api_key_hash, row.lifetime, row.user_id)

async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool:
table = await self.db.Table("ApiKeys")
row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)})
if "Item" not in row:
return False
return row["Item"]["user_id"] == str(user_id)
row = await self.kv.get(hash_api_key(api_key))
return row is not None and row == user_id

async def delete_api_key(self, api_key: uuid.UUID) -> None:
table = await self.db.Table("ApiKeys")
await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)})
await self.kv.delete(hash_api_key(api_key))


async def test_adhoc() -> None:
Expand Down
9 changes: 0 additions & 9 deletions store/app/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ async def create_tables(crud: Crud | None = None) -> None:
("emailIndex", "email", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="ApiKeys",
keys=[
("api_key_hash", "S", "HASH"),
],
gsis=[
("userIdIndex", "user_id", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="Robots",
keys=[
Expand Down
5 changes: 3 additions & 2 deletions store/app/api/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ async def send_email(subject: str, body: str, to: str) -> None:
@dataclass
class OneTimePassPayload:
email: str
lifetime: str

def encode(self) -> str:
expire_minutes = settings.crypto.expire_otp_minutes
expire_after = datetime.timedelta(minutes=expire_minutes)
return encode_jwt({"email": self.email}, expire_after=expire_after)
return encode_jwt({"email": self.email, "lifetime": self.lifetime}, expire_after=expire_after)

@classmethod
def decode(cls, payload: str) -> "OneTimePassPayload":
data = decode_jwt(payload)
return cls(email=data["email"])
return cls(email=data["email"], lifetime=data["lifetime"])


async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None:
Expand Down
10 changes: 5 additions & 5 deletions store/app/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
expects (for example, converting a UUID into a string).
"""

import datetime
import uuid
from dataclasses import field
from decimal import Decimal

from pydantic import BaseModel
Expand All @@ -28,14 +26,16 @@ def to_uuid(self) -> uuid.UUID:


class ApiKey(BaseModel):
"""Stored in Redis rather than DynamoDB."""

api_key_hash: str # Primary key
user_id: str
issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp()))
lifetime: int

@classmethod
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID) -> "ApiKey":
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> "ApiKey":
api_key_hash = hash_api_key(api_key)
return cls(api_key_hash=api_key_hash, user_id=str(user_id))
return cls(api_key_hash=api_key_hash, user_id=str(user_id), lifetime=lifetime)


class PurchaseLink(BaseModel):
Expand Down
13 changes: 8 additions & 5 deletions store/app/api/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def get_api_key(request: Request) -> ApiKeyData:
class UserSignup(BaseModel):
email: str
login_url: str
lifetime: int


def validate_email(email: str) -> str:
Expand All @@ -76,7 +77,7 @@ async def login_user_endpoint(data: UserSignup) -> bool:
True if the email was sent successfully.
"""
email = validate_email(data.email)
payload = OneTimePassPayload(email)
payload = OneTimePassPayload(email, lifetime=str(data.lifetime))
await send_otp_email(payload, data.login_url)
return True

Expand All @@ -89,7 +90,7 @@ class UserLoginResponse(BaseModel):
api_key: str


async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
async def get_login_response(email: str, lifetime: int, crud: Crud) -> UserLoginResponse:
"""Takes the user email and returns an API key.
This function gets a user API key for an email which has been validated,
Expand All @@ -98,6 +99,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
Args:
email: The validated email of the user.
crud: The database CRUD object.
lifetime: The lifetime (in seconds) of the API key to be returned.
Returns:
The API key for the user.
Expand All @@ -112,7 +114,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
# Issue a new API key for the user.
user_id: uuid.UUID = user_obj.to_uuid()
api_key: uuid.UUID = get_new_api_key(user_id)
await crud.add_api_key(api_key, user_id)
await crud.add_api_key(api_key, user_id, lifetime)

return UserLoginResponse(api_key=str(api_key))

Expand All @@ -132,7 +134,7 @@ async def otp_endpoint(
The API key if the one-time password is valid.
"""
payload = OneTimePassPayload.decode(data.payload)
return await get_login_response(payload.email, crud)
return await get_login_response(payload.email, int(payload.lifetime), crud)


async def get_google_user_info(token: str) -> dict:
Expand All @@ -155,6 +157,7 @@ async def google_login_endpoint(
data: GoogleLogin,
crud: Annotated[Crud, Depends(Crud.get)],
) -> UserLoginResponse:
"""Uses Google OAuth to create an API token that lasts for a week (i.e. 604800 seconds)."""
try:
idinfo = await get_google_user_info(data.token)
email = idinfo["email"]
Expand All @@ -163,7 +166,7 @@ async def google_login_endpoint(
if idinfo.get("email_verified") is not True:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Google email not verified")

return await get_login_response(email, crud)
return await get_login_response(email, 604800, crud)


class UserInfoResponse(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions store/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ruff

# Testing
moto[dynamodb]
fakeredis
pytest
pytest-aiohttp
pytest-aiomoto
Expand Down
1 change: 1 addition & 0 deletions store/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pydantic

# AWS dependencies.
aioboto3
redis

# FastAPI dependencies.
aiohttp
Expand Down
9 changes: 9 additions & 0 deletions store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
from omegaconf import II, MISSING


@dataclass
class RedisSettings:
host: str = field(default=II("oc.env:ROBOLIST_REDIS_HOST"))
password: str = field(default=II("oc.env:ROBOLIST_REDIS_PASSWORD"))
port: int = field(default=6379)
db: int = field(default=0)


@dataclass
class CryptoSettings:
expire_token_minutes: int = field(default=10)
Expand Down Expand Up @@ -38,6 +46,7 @@ class SiteSettings:

@dataclass
class EnvironmentSettings:
redis: RedisSettings = field(default_factory=RedisSettings)
user: UserSettings = field(default_factory=UserSettings)
crypto: CryptoSettings = field(default_factory=CryptoSettings)
email: EmailSettings = field(default_factory=EmailSettings)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType)
login_url = "/"

# Sends the one-time password to the test email.
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url})
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url, "lifetime": 3600})
assert response.status_code == 200, response.json()
assert mock_send_email.call_count == 1

# Uses the one-time pass to get an API key. We need to make a new OTP
# manually because we can't send emails in unit tests.
otp = OneTimePassPayload(email=test_email)
otp = OneTimePassPayload(email=test_email, lifetime=3600)
response = app_client.post("/api/users/otp", json={"payload": otp.encode()})
assert response.status_code == 200, response.json()
response_data = response.json()
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.testclient import TestClient
from moto.dynamodb import mock_dynamodb
from moto.server import ThreadedMotoServer
import fakeredis
from pytest_mock.plugin import MockerFixture, MockType

os.environ["ROBOLIST_ENVIRONMENT"] = "local"
Expand Down Expand Up @@ -57,6 +58,14 @@ def mock_aws() -> Generator[None, None, None]:
else:
os.environ[k] = v

@pytest.fixture(autouse=True)
def mock_redis(mocker: MockerFixture) -> None:
os.environ["ROBOLIST_REDIS_HOST"] = "localhost"
os.environ["ROBOLIST_REDIS_PASSWORD"] = ""
os.environ["ROBOLIST_REDIS_PORT"] = "6379"
os.environ["ROBOLIST_REDIS_DB"] = "0"
fake_redis = fakeredis.aioredis.FakeRedis()
mocker.patch("store.app.api.crud.base.Redis", return_value=fake_redis)

@pytest.fixture()
def app_client() -> Generator[TestClient, None, None]:
Expand Down

0 comments on commit 89be7b6

Please sign in to comment.