Skip to content

Commit

Permalink
Additional testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ivntsng committed Nov 12, 2024
1 parent 085daf5 commit 4cad4fc
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 9 deletions.
41 changes: 34 additions & 7 deletions store/app/routers/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
import logging
import os
from typing import Annotated, Self
from datetime import datetime, timedelta
from typing import Annotated, Literal, Self

from boto3.dynamodb.conditions import Key
from fastapi import APIRouter, Depends, HTTPException, UploadFile, status
Expand All @@ -24,7 +25,10 @@
get_artifact_url,
get_artifact_urls,
)
from store.app.security.user import get_session_user_with_write_permission, maybe_get_user_from_api_key
from store.app.security.user import (
get_session_user_with_write_permission,
maybe_get_user_from_api_key,
)
from store.app.utils.cloudfront_signer import CloudFrontUrlSigner
from store.settings import settings

Expand Down Expand Up @@ -81,7 +85,7 @@ async def artifact_url(

# Create and sign URL for production environment
if settings.environment != "local":
policy = signer.create_custom_policy(url=base_url, expire_days=1 / 24) # 1 hour expiration
policy = signer.create_custom_policy(url=base_url, expire_days=180)
base_url = signer.generate_presigned_url(base_url, policy=policy)

return RedirectResponse(url=base_url)
Expand All @@ -90,14 +94,37 @@ async def artifact_url(
class ArtifactUrls(BaseModel):
small: str | None = None
large: str
expires_at: int


def get_artifact_url_response(artifact: Artifact) -> ArtifactUrls:
artifact_urls = get_artifact_urls(artifact=artifact)
return ArtifactUrls(
small=artifact_urls.get("small"),
large=artifact_urls["large"],
)
expiration_time = None

# If in production, sign both URLs
if settings.environment != "local":
logger.debug(f"Original URLs for artifact {artifact.id}: {artifact_urls}")

signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
)

expire_days = 180
expiration_time = int((datetime.utcnow() + timedelta(days=expire_days)).timestamp())

# Explicitly iterate over the literal types
sizes: list[Literal["small", "large"]] = ["small", "large"]
for size in sizes:
try:
url = artifact_urls[size]
cf_url = f"https://{settings.cloudfront.domain}/{url}"
policy = signer.create_custom_policy(url=cf_url, expire_days=expire_days)
artifact_urls[size] = signer.generate_presigned_url(cf_url, policy=policy)
except KeyError:
continue

return ArtifactUrls(small=artifact_urls.get("small"), large=artifact_urls["large"], expires_at=expiration_time or 0)


class SingleArtifactResponse(BaseModel):
Expand Down
15 changes: 14 additions & 1 deletion store/app/routers/listings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
get_session_user_with_write_permission,
maybe_get_user_from_api_key,
)
from store.settings.environment import EnvironmentSettings

# Create settings instance
settings = EnvironmentSettings()

router = APIRouter()

Expand Down Expand Up @@ -612,7 +616,16 @@ async def get_listing(
listing = await crud.get_listing(listing_id)
if listing is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Listing not found")
return await get_listing_common(listing, user, crud)

response = await get_listing_common(listing, user, crud)

# Verify URLs are signed in production
if settings.environment != "local":
for artifact in response.artifacts:
if not any("Key-Pair-Id=" in url for url in [artifact.urls.small, artifact.urls.large] if url is not None):
logger.error(f"Unsigned URLs found for artifact {artifact.artifact_id} in listing {listing_id}")

return response


@router.get("/{username}/{slug}", response_model=GetListingResponse)
Expand Down
2 changes: 1 addition & 1 deletion store/app/utils/cloudfront_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _rsa_signer(self, message: bytes) -> bytes:
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")

return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1()) # CloudFront requires SHA-1
return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())

def generate_presigned_url(self, url: str, policy: Optional[str] = None) -> str:
"""Generate a presigned URL for CloudFront using an optional custom policy.
Expand Down

0 comments on commit 4cad4fc

Please sign in to comment.