Skip to content

Commit

Permalink
Additional CF test (#595)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivntsng authored Nov 12, 2024
1 parent 19f4efb commit f045e4f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 39 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ env:
CLOUDFRONT_KEY_ID: ${{ secrets.CLOUDFRONT_KEY_ID }}
CLOUDFRONT_PRIVATE_KEY: ${{ secrets.CLOUDFRONT_PRIVATE_KEY }}
CLOUDFRONT_DOMAIN: ${{ secrets.CLOUDFRONT_DOMAIN }}
CLOUDFRONT_PRIVATE_KEY_PATH: /tmp/dummy_key.pem
ONSHAPE_API: ${{ secrets.ONSHAPE_API }}
ONSHAPE_ACCESS_KEY: ${{ secrets.ONSHAPE_ACCESS_KEY }}
ONSHAPE_SECRET_KEY: ${{ secrets.ONSHAPE_SECRET_KEY }}
Expand Down
44 changes: 19 additions & 25 deletions store/app/routers/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
can_write_listing,
check_content_type,
get_artifact_type,
get_artifact_url,
get_artifact_urls,
)
from store.app.security.user import (
Expand Down Expand Up @@ -65,30 +64,19 @@ async def artifact_url(
# Initialize CloudFront signer
signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
private_key=settings.cloudfront.private_key,
)

# Generate base URL based on environment
if settings.environment == "local":
base_url = get_artifact_url(
artifact_type=artifact.artifact_type,
artifact_id=artifact.id,
listing_id=listing_id,
name=s3_filename,
size=size,
)
else:
# For production, use CloudFront domain
base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}"
if size and artifact.artifact_type == "image":
base_url = f"{base_url}_{size}"
# Always use CloudFront domain and sign the URL
base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}"
if size and artifact.artifact_type == "image":
base_url = f"{base_url}_{size}"

# Create and sign URL for production environment
if settings.environment != "local":
policy = signer.create_custom_policy(url=base_url, expire_days=180)
base_url = signer.generate_presigned_url(base_url, policy=policy)
# Create and sign URL
policy = signer.create_custom_policy(url=base_url, expire_days=180)
signed_url = signer.generate_presigned_url(base_url, policy=policy)

return RedirectResponse(url=base_url)
return RedirectResponse(url=signed_url)


class ArtifactUrls(BaseModel):
Expand All @@ -107,18 +95,24 @@ def get_artifact_url_response(artifact: Artifact) -> ArtifactUrls:

signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
private_key=settings.cloudfront.private_key,
)

expire_days = 180
expiration_time = int((datetime.utcnow() + timedelta(days=expire_days)).timestamp())

# Explicitly iterate over the literal types
sizes: list[Literal["small", "large"]] = ["small", "large"]
for size in sizes:
try:
url = artifact_urls[size]
cf_url = f"https://{settings.cloudfront.domain}/{url}"
cf_url = (
f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{artifact.listing_id}/{artifact.id}"
)
if size == "small":
cf_url += "_small_256x256"
elif size == "large":
cf_url += "_large_1536x1536"
cf_url += f"_{artifact.name}"

policy = signer.create_custom_policy(url=cf_url, expire_days=expire_days)
artifact_urls[size] = signer.generate_presigned_url(cf_url, policy=policy)
except KeyError:
Expand Down
23 changes: 11 additions & 12 deletions store/app/utils/cloudfront_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
class CloudFrontUrlSigner:
"""A class to generate signed URLs for AWS CloudFront using RSA keys."""

def __init__(self, key_id: str, private_key_path: str) -> None:
"""Initialize the CloudFrontUrlSigner with a key ID and the path to the private key file.
def __init__(self, key_id: str, private_key: str) -> None:
"""Initialize the CloudFrontUrlSigner with a key ID and private key content.
:param key_id: The CloudFront key ID associated with the public key in your CloudFront key group.
:param private_key_path: The path to the private key PEM file.
:param private_key: The private key content in PEM format.
"""
self.key_id = key_id
self.private_key_path = private_key_path
self.private_key = private_key
self.cf_signer = CloudFrontSigner(key_id, self._rsa_signer)

def _rsa_signer(self, message: bytes) -> bytes:
Expand All @@ -38,16 +38,15 @@ def _rsa_signer(self, message: bytes) -> bytes:
Raises:
ValueError: If the loaded key is not an RSA private key.
"""
with open(self.private_key_path, "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
)
private_key = serialization.load_pem_private_key(
self.private_key.encode("utf-8"),
password=None,
)

if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")

return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())
return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())

def generate_presigned_url(self, url: str, policy: Optional[str] = None) -> str:
"""Generate a presigned URL for CloudFront using an optional custom policy.
Expand Down
2 changes: 1 addition & 1 deletion store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class StripeSettings:
class CloudFrontSettings:
domain: str = field(default=II("oc.env:CLOUDFRONT_DOMAIN"))
key_id: str = field(default=II("oc.env:CLOUDFRONT_KEY_ID"))
private_key_path: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY_PATH"))
private_key: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY"))


@dataclass
Expand Down

0 comments on commit f045e4f

Please sign in to comment.