diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7974c8cf..8cff2362 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,6 +23,10 @@ env: AWS_SECRET_ACCESS_KEY: test AWS_ENDPOINT_URL_DYNAMODB: http://localhost:8000 AWS_REGION: us-east-1 + CLOUDFRONT_KEY_ID: ${{ secrets.CLOUDFRONT_KEY_ID }} + CLOUDFRONT_PRIVATE_KEY: ${{ secrets.CLOUDFRONT_PRIVATE_KEY }} + CLOUDFRONT_DOMAIN: ${{ secrets.CLOUDFRONT_DOMAIN }} + CLOUDFRONT_PRIVATE_KEY_PATH: /tmp/dummy_key.pem ONSHAPE_API: ${{ secrets.ONSHAPE_API }} ONSHAPE_ACCESS_KEY: ${{ secrets.ONSHAPE_ACCESS_KEY }} ONSHAPE_SECRET_KEY: ${{ secrets.ONSHAPE_SECRET_KEY }} diff --git a/store/app/routers/artifacts.py b/store/app/routers/artifacts.py index f9c5a540..fd4137b2 100644 --- a/store/app/routers/artifacts.py +++ b/store/app/routers/artifacts.py @@ -25,6 +25,7 @@ get_artifact_urls, ) from store.app.security.user import get_session_user_with_write_permission, maybe_get_user_from_api_key +from store.app.utils.cloudfront_signer import CloudFrontUrlSigner from store.settings import settings router = APIRouter() @@ -57,16 +58,33 @@ async def artifact_url( _, file_extension = os.path.splitext(name) s3_filename = f"{artifact.id}{file_extension}" - # TODO: Use CloudFront API to return a signed CloudFront URL. - return RedirectResponse( - url=get_artifact_url( + # Initialize CloudFront signer + signer = CloudFrontUrlSigner( + key_id=settings.cloudfront.key_id, + private_key_path=settings.cloudfront.private_key_path, + ) + + # Generate base URL based on environment + if settings.environment == "local": + base_url = get_artifact_url( artifact_type=artifact.artifact_type, artifact_id=artifact.id, listing_id=listing_id, name=s3_filename, size=size, ) - ) + else: + # For production, use CloudFront domain + base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}" + if size and artifact.artifact_type == "image": + base_url = f"{base_url}_{size}" + + # Create and sign URL for production environment + if settings.environment != "local": + policy = signer.create_custom_policy(url=base_url, expire_days=1 / 24) # 1 hour expiration + base_url = signer.generate_presigned_url(base_url, policy=policy) + + return RedirectResponse(url=base_url) class ArtifactUrls(BaseModel): diff --git a/store/app/utils/cloudfront_signer.py b/store/app/utils/cloudfront_signer.py new file mode 100644 index 00000000..59429725 --- /dev/null +++ b/store/app/utils/cloudfront_signer.py @@ -0,0 +1,83 @@ +"""This module provides a class to generate signed URLs for AWS CloudFront using RSA keys. + +The `CloudFrontUrlSigner` class allows you to create and sign CloudFront URLs with optional custom policies. +""" + +import json +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from botocore.signers import CloudFrontSigner +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + + +class CloudFrontUrlSigner: + """A class to generate signed URLs for AWS CloudFront using RSA keys.""" + + def __init__(self, key_id: str, private_key_path: str) -> None: + """Initialize the CloudFrontUrlSigner with a key ID and the path to the private key file. + + :param key_id: The CloudFront key ID associated with the public key in your CloudFront key group. + :param private_key_path: The path to the private key PEM file. + """ + self.key_id = key_id + self.private_key_path = private_key_path + self.cf_signer = CloudFrontSigner(key_id, self._rsa_signer) + + def _rsa_signer(self, message: bytes) -> bytes: + """RSA signer function that signs a message using the private key. + + Args: + message: The message to be signed. + + Returns: + bytes: The RSA signature of the message. + + Raises: + ValueError: If the loaded key is not an RSA private key. + """ + with open(self.private_key_path, "rb") as key_file: + private_key = serialization.load_pem_private_key( + key_file.read(), + password=None, + ) + + if not isinstance(private_key, RSAPrivateKey): + raise ValueError("The provided key is not an RSA private key") + + return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1()) # CloudFront requires SHA-1 + + def generate_presigned_url(self, url: str, policy: Optional[str] = None) -> str: + """Generate a presigned URL for CloudFront using an optional custom policy. + + :param url: The URL to sign. + :param policy: (Optional) A custom policy for the URL. + :return: The signed URL. + """ + return self.cf_signer.generate_presigned_url(url, policy=policy) + + def create_custom_policy(self, url: str, expire_days: float = 1, ip_range: Optional[str] = None) -> str: + """Create a custom policy for CloudFront signed URLs. + + :param url: The URL to be signed. + :param expire_days: Number of days until the policy expires (can be fractional, e.g., 1/24 for one hour). + :param ip_range: Optional IP range to restrict access (e.g., "203.0.113.0/24"). + :return: The custom policy in JSON format. + """ + expiration_time = int((datetime.utcnow() + timedelta(days=expire_days)).timestamp()) + policy: Dict[str, Any] = { + "Statement": [ + { + "Resource": url, + "Condition": { + "DateLessThan": {"AWS:EpochTime": expiration_time}, + }, + } + ] + } + if ip_range: + policy["Statement"][0]["Condition"]["IpAddress"] = {"AWS:SourceIp": ip_range} + + return json.dumps(policy, separators=(",", ":")) diff --git a/store/requirements.txt b/store/requirements.txt index b031f421..055237c3 100644 --- a/store/requirements.txt +++ b/store/requirements.txt @@ -14,7 +14,7 @@ aioboto3 argon2-cffi pyjwt[asyncio] bcrypt - +cryptography # FastAPI dependencies. aiosmtplib fastapi[standard] diff --git a/store/settings/environment.py b/store/settings/environment.py index 0d086445..1e534251 100644 --- a/store/settings/environment.py +++ b/store/settings/environment.py @@ -71,6 +71,13 @@ class StripeSettings: webhook_secret: str = field(default=II("oc.env:STRIPE_WEBHOOK_SECRET")) +@dataclass +class CloudFrontSettings: + domain: str = field(default=II("oc.env:CLOUDFRONT_DOMAIN")) + key_id: str = field(default=II("oc.env:CLOUDFRONT_KEY_ID")) + private_key_path: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY_PATH")) + + @dataclass class EnvironmentSettings: oauth: OauthSettings = field(default_factory=OauthSettings) @@ -81,5 +88,7 @@ class EnvironmentSettings: s3: S3Settings = field(default_factory=S3Settings) dynamo: DynamoSettings = field(default_factory=DynamoSettings) site: SiteSettings = field(default_factory=SiteSettings) + cloudfront: CloudFrontSettings = field(default_factory=CloudFrontSettings) debug: bool = field(default=False) stripe: StripeSettings = field(default_factory=StripeSettings) + environment: str = field(default="local")