Skip to content

Commit

Permalink
Cloudfront Testing (#590)
Browse files Browse the repository at this point in the history
* Cloudfront Testing

* Unit test

* Added env variable
  • Loading branch information
ivntsng authored Nov 12, 2024
1 parent 3aee0ef commit 4255cae
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ env:
AWS_SECRET_ACCESS_KEY: test
AWS_ENDPOINT_URL_DYNAMODB: http://localhost:8000
AWS_REGION: us-east-1
CLOUDFRONT_KEY_ID: ${{ secrets.CLOUDFRONT_KEY_ID }}
CLOUDFRONT_PRIVATE_KEY: ${{ secrets.CLOUDFRONT_PRIVATE_KEY }}
CLOUDFRONT_DOMAIN: ${{ secrets.CLOUDFRONT_DOMAIN }}
CLOUDFRONT_PRIVATE_KEY_PATH: /tmp/dummy_key.pem
ONSHAPE_API: ${{ secrets.ONSHAPE_API }}
ONSHAPE_ACCESS_KEY: ${{ secrets.ONSHAPE_ACCESS_KEY }}
ONSHAPE_SECRET_KEY: ${{ secrets.ONSHAPE_SECRET_KEY }}
Expand Down
26 changes: 22 additions & 4 deletions store/app/routers/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_artifact_urls,
)
from store.app.security.user import get_session_user_with_write_permission, maybe_get_user_from_api_key
from store.app.utils.cloudfront_signer import CloudFrontUrlSigner
from store.settings import settings

router = APIRouter()
Expand Down Expand Up @@ -57,16 +58,33 @@ async def artifact_url(
_, file_extension = os.path.splitext(name)
s3_filename = f"{artifact.id}{file_extension}"

# TODO: Use CloudFront API to return a signed CloudFront URL.
return RedirectResponse(
url=get_artifact_url(
# Initialize CloudFront signer
signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
)

# Generate base URL based on environment
if settings.environment == "local":
base_url = get_artifact_url(
artifact_type=artifact.artifact_type,
artifact_id=artifact.id,
listing_id=listing_id,
name=s3_filename,
size=size,
)
)
else:
# For production, use CloudFront domain
base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}"
if size and artifact.artifact_type == "image":
base_url = f"{base_url}_{size}"

# Create and sign URL for production environment
if settings.environment != "local":
policy = signer.create_custom_policy(url=base_url, expire_days=1 / 24) # 1 hour expiration
base_url = signer.generate_presigned_url(base_url, policy=policy)

return RedirectResponse(url=base_url)


class ArtifactUrls(BaseModel):
Expand Down
83 changes: 83 additions & 0 deletions store/app/utils/cloudfront_signer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""This module provides a class to generate signed URLs for AWS CloudFront using RSA keys.
The `CloudFrontUrlSigner` class allows you to create and sign CloudFront URLs with optional custom policies.
"""

import json
from datetime import datetime, timedelta
from typing import Any, Dict, Optional

from botocore.signers import CloudFrontSigner
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey


class CloudFrontUrlSigner:
"""A class to generate signed URLs for AWS CloudFront using RSA keys."""

def __init__(self, key_id: str, private_key_path: str) -> None:
"""Initialize the CloudFrontUrlSigner with a key ID and the path to the private key file.
:param key_id: The CloudFront key ID associated with the public key in your CloudFront key group.
:param private_key_path: The path to the private key PEM file.
"""
self.key_id = key_id
self.private_key_path = private_key_path
self.cf_signer = CloudFrontSigner(key_id, self._rsa_signer)

def _rsa_signer(self, message: bytes) -> bytes:
"""RSA signer function that signs a message using the private key.
Args:
message: The message to be signed.
Returns:
bytes: The RSA signature of the message.
Raises:
ValueError: If the loaded key is not an RSA private key.
"""
with open(self.private_key_path, "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
)

if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")

return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1()) # CloudFront requires SHA-1

def generate_presigned_url(self, url: str, policy: Optional[str] = None) -> str:
"""Generate a presigned URL for CloudFront using an optional custom policy.
:param url: The URL to sign.
:param policy: (Optional) A custom policy for the URL.
:return: The signed URL.
"""
return self.cf_signer.generate_presigned_url(url, policy=policy)

def create_custom_policy(self, url: str, expire_days: float = 1, ip_range: Optional[str] = None) -> str:
"""Create a custom policy for CloudFront signed URLs.
:param url: The URL to be signed.
:param expire_days: Number of days until the policy expires (can be fractional, e.g., 1/24 for one hour).
:param ip_range: Optional IP range to restrict access (e.g., "203.0.113.0/24").
:return: The custom policy in JSON format.
"""
expiration_time = int((datetime.utcnow() + timedelta(days=expire_days)).timestamp())
policy: Dict[str, Any] = {
"Statement": [
{
"Resource": url,
"Condition": {
"DateLessThan": {"AWS:EpochTime": expiration_time},
},
}
]
}
if ip_range:
policy["Statement"][0]["Condition"]["IpAddress"] = {"AWS:SourceIp": ip_range}

return json.dumps(policy, separators=(",", ":"))
2 changes: 1 addition & 1 deletion store/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ aioboto3
argon2-cffi
pyjwt[asyncio]
bcrypt

cryptography
# FastAPI dependencies.
aiosmtplib
fastapi[standard]
Expand Down
9 changes: 9 additions & 0 deletions store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ class StripeSettings:
webhook_secret: str = field(default=II("oc.env:STRIPE_WEBHOOK_SECRET"))


@dataclass
class CloudFrontSettings:
domain: str = field(default=II("oc.env:CLOUDFRONT_DOMAIN"))
key_id: str = field(default=II("oc.env:CLOUDFRONT_KEY_ID"))
private_key_path: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY_PATH"))


@dataclass
class EnvironmentSettings:
oauth: OauthSettings = field(default_factory=OauthSettings)
Expand All @@ -81,5 +88,7 @@ class EnvironmentSettings:
s3: S3Settings = field(default_factory=S3Settings)
dynamo: DynamoSettings = field(default_factory=DynamoSettings)
site: SiteSettings = field(default_factory=SiteSettings)
cloudfront: CloudFrontSettings = field(default_factory=CloudFrontSettings)
debug: bool = field(default=False)
stripe: StripeSettings = field(default_factory=StripeSettings)
environment: str = field(default="local")

0 comments on commit 4255cae

Please sign in to comment.