Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to Redis and add lifetime for API keys #42

Merged
merged 21 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions store/app/api/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

import aioboto3
from botocore.exceptions import ClientError
from redis.asyncio import Redis
from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource

from store.settings import settings

logger = logging.getLogger(__name__)


Expand All @@ -16,6 +19,7 @@ def __init__(self) -> None:
super().__init__()

self.__db: DynamoDBServiceResource | None = None
self.__kv: Redis | None = None

@property
def db(self) -> DynamoDBServiceResource:
Expand All @@ -28,6 +32,13 @@ async def __aenter__(self) -> Self:
db = session.resource("dynamodb")
db = await db.__aenter__()
self.__db = db

self.kv = Redis(
host=settings.redis.host,
password=settings.redis.password,
port=settings.redis.port,
db=settings.redis.db,
)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401
Expand Down
26 changes: 9 additions & 17 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import uuid
import warnings
from typing import cast

from boto3.dynamodb.conditions import Key as KeyCondition

Expand Down Expand Up @@ -37,13 +36,11 @@ async def get_user_from_email(self, email: str) -> User | None:
return user

async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None:
table = await self.db.Table("ApiKeys")
api_key_hash = hash_api_key(api_key)
row = await table.get_item(Key={"api_key_hash": api_key_hash})
if "Item" not in row:
user_id = await self.kv.get(api_key_hash)
if user_id is None:
return None
user_id = cast(str, row["Item"]["user_id"])
return uuid.UUID(user_id)
return uuid.UUID(user_id.decode("utf-8"))

async def delete_user(self, user: User) -> None:
table = await self.db.Table("Users")
Expand All @@ -59,21 +56,16 @@ async def get_user_count(self) -> int:
table = await self.db.Table("Users")
return await table.item_count

async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None:
row = ApiKey.from_api_key(api_key, user_id)
table = await self.db.Table("ApiKeys")
await table.put_item(Item=row.model_dump())
async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> None:
row = ApiKey.from_api_key(api_key, user_id, lifetime)
await self.kv.setex(row.api_key_hash, row.lifetime, row.user_id)

async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool:
table = await self.db.Table("ApiKeys")
row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)})
if "Item" not in row:
return False
return row["Item"]["user_id"] == str(user_id)
row = await self.kv.get(hash_api_key(api_key))
return row is not None and row == user_id

async def delete_api_key(self, api_key: uuid.UUID) -> None:
table = await self.db.Table("ApiKeys")
await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)})
await self.kv.delete(hash_api_key(api_key))


async def test_adhoc() -> None:
Expand Down
9 changes: 0 additions & 9 deletions store/app/api/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ async def create_tables(crud: Crud | None = None) -> None:
("emailIndex", "email", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="ApiKeys",
keys=[
("api_key_hash", "S", "HASH"),
],
gsis=[
("userIdIndex", "user_id", "S", "HASH"),
],
)
await crud._create_dynamodb_table(
name="Robots",
keys=[
Expand Down
5 changes: 3 additions & 2 deletions store/app/api/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ async def send_email(subject: str, body: str, to: str) -> None:
@dataclass
class OneTimePassPayload:
email: str
lifetime: str

def encode(self) -> str:
expire_minutes = settings.crypto.expire_otp_minutes
expire_after = datetime.timedelta(minutes=expire_minutes)
return encode_jwt({"email": self.email}, expire_after=expire_after)
return encode_jwt({"email": self.email, "lifetime": self.lifetime}, expire_after=expire_after)

@classmethod
def decode(cls, payload: str) -> "OneTimePassPayload":
data = decode_jwt(payload)
return cls(email=data["email"])
return cls(email=data["email"], lifetime=data["lifetime"])


async def send_otp_email(payload: OneTimePassPayload, login_url: str) -> None:
Expand Down
10 changes: 5 additions & 5 deletions store/app/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
expects (for example, converting a UUID into a string).
"""

import datetime
import uuid
from dataclasses import field
from decimal import Decimal

from pydantic import BaseModel
Expand All @@ -28,14 +26,16 @@ def to_uuid(self) -> uuid.UUID:


class ApiKey(BaseModel):
"""Stored in Redis rather than DynamoDB."""

api_key_hash: str # Primary key
user_id: str
issued: Decimal = field(default_factory=lambda: Decimal(datetime.datetime.now().timestamp()))
lifetime: int

@classmethod
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID) -> "ApiKey":
def from_api_key(cls, api_key: uuid.UUID, user_id: uuid.UUID, lifetime: int) -> "ApiKey":
api_key_hash = hash_api_key(api_key)
return cls(api_key_hash=api_key_hash, user_id=str(user_id))
return cls(api_key_hash=api_key_hash, user_id=str(user_id), lifetime=lifetime)


class PurchaseLink(BaseModel):
Expand Down
12 changes: 7 additions & 5 deletions store/app/api/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def get_api_key(request: Request) -> ApiKeyData:
class UserSignup(BaseModel):
email: str
login_url: str
lifetime: int


def validate_email(email: str) -> str:
Expand All @@ -76,7 +77,7 @@ async def login_user_endpoint(data: UserSignup) -> bool:
True if the email was sent successfully.
"""
email = validate_email(data.email)
payload = OneTimePassPayload(email)
payload = OneTimePassPayload(email, lifetime=str(data.lifetime))
await send_otp_email(payload, data.login_url)
return True

Expand All @@ -89,7 +90,7 @@ class UserLoginResponse(BaseModel):
api_key: str


async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
async def get_login_response(email: str, lifetime: int, crud: Crud) -> UserLoginResponse:
"""Takes the user email and returns an API key.

This function gets a user API key for an email which has been validated,
Expand All @@ -98,6 +99,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
Args:
email: The validated email of the user.
crud: The database CRUD object.
lifetime: The lifetime (in seconds) of the API key to be returned.

Returns:
The API key for the user.
Expand All @@ -112,7 +114,7 @@ async def get_login_response(email: str, crud: Crud) -> UserLoginResponse:
# Issue a new API key for the user.
user_id: uuid.UUID = user_obj.to_uuid()
api_key: uuid.UUID = get_new_api_key(user_id)
await crud.add_api_key(api_key, user_id)
await crud.add_api_key(api_key, user_id, lifetime)

return UserLoginResponse(api_key=str(api_key))

Expand All @@ -132,7 +134,7 @@ async def otp_endpoint(
The API key if the one-time password is valid.
"""
payload = OneTimePassPayload.decode(data.payload)
return await get_login_response(payload.email, crud)
return await get_login_response(payload.email, int(payload.lifetime), crud)


async def get_google_user_info(token: str) -> dict:
Expand Down Expand Up @@ -163,7 +165,7 @@ async def google_login_endpoint(
if idinfo.get("email_verified") is not True:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Google email not verified")

return await get_login_response(email, crud)
return await get_login_response(email, 604800, crud)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magic number :) make this a comment / add a docstring



class UserInfoResponse(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions store/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ruff

# Testing
moto[dynamodb]
fakeredis
pytest
pytest-aiohttp
pytest-aiomoto
Expand Down
1 change: 1 addition & 0 deletions store/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pydantic

# AWS dependencies.
aioboto3
redis

# FastAPI dependencies.
aiohttp
Expand Down
9 changes: 9 additions & 0 deletions store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
from omegaconf import II, MISSING


@dataclass
class RedisSettings:
host: str = field(default=II("oc.env:ROBOLIST_REDIS_HOST"))
password: str = field(default=II("oc.env:ROBOLIST_REDIS_PASSWORD"))
port: int = field(default=6379)
db: int = field(default=0)


@dataclass
class CryptoSettings:
expire_token_minutes: int = field(default=10)
Expand Down Expand Up @@ -38,6 +46,7 @@ class SiteSettings:

@dataclass
class EnvironmentSettings:
redis: RedisSettings = field(default_factory=RedisSettings)
user: UserSettings = field(default_factory=UserSettings)
crypto: CryptoSettings = field(default_factory=CryptoSettings)
email: EmailSettings = field(default_factory=EmailSettings)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def test_user_auth_functions(app_client: TestClient, mock_send_email: MockType)
login_url = "/"

# Sends the one-time password to the test email.
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url})
response = app_client.post("/api/users/login", json={"email": test_email, "login_url": login_url, "lifetime": 3600})
codekansas marked this conversation as resolved.
Show resolved Hide resolved
assert response.status_code == 200, response.json()
assert mock_send_email.call_count == 1

# Uses the one-time pass to get an API key. We need to make a new OTP
# manually because we can't send emails in unit tests.
otp = OneTimePassPayload(email=test_email)
otp = OneTimePassPayload(email=test_email, lifetime=3600)
response = app_client.post("/api/users/otp", json={"payload": otp.encode()})
assert response.status_code == 200, response.json()
response_data = response.json()
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.testclient import TestClient
from moto.dynamodb import mock_dynamodb
from moto.server import ThreadedMotoServer
import fakeredis
from pytest_mock.plugin import MockerFixture, MockType

os.environ["ROBOLIST_ENVIRONMENT"] = "local"
Expand Down Expand Up @@ -57,6 +58,14 @@ def mock_aws() -> Generator[None, None, None]:
else:
os.environ[k] = v

@pytest.fixture(autouse=True)
def mock_redis(mocker: MockerFixture) -> None:
os.environ["ROBOLIST_REDIS_HOST"] = "localhost"
os.environ["ROBOLIST_REDIS_PASSWORD"] = ""
os.environ["ROBOLIST_REDIS_PORT"] = "6379"
os.environ["ROBOLIST_REDIS_DB"] = "0"
fake_redis = fakeredis.aioredis.FakeRedis()
mocker.patch("store.app.api.crud.base.Redis", return_value=fake_redis)

@pytest.fixture()
def app_client() -> Generator[TestClient, None, None]:
Expand Down
Loading