Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional CF test #595

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ env:
CLOUDFRONT_KEY_ID: ${{ secrets.CLOUDFRONT_KEY_ID }}
CLOUDFRONT_PRIVATE_KEY: ${{ secrets.CLOUDFRONT_PRIVATE_KEY }}
CLOUDFRONT_DOMAIN: ${{ secrets.CLOUDFRONT_DOMAIN }}
CLOUDFRONT_PRIVATE_KEY_PATH: /tmp/dummy_key.pem
ONSHAPE_API: ${{ secrets.ONSHAPE_API }}
ONSHAPE_ACCESS_KEY: ${{ secrets.ONSHAPE_ACCESS_KEY }}
ONSHAPE_SECRET_KEY: ${{ secrets.ONSHAPE_SECRET_KEY }}
Expand Down
44 changes: 19 additions & 25 deletions store/app/routers/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
can_write_listing,
check_content_type,
get_artifact_type,
get_artifact_url,
get_artifact_urls,
)
from store.app.security.user import (
Expand Down Expand Up @@ -65,30 +64,19 @@ async def artifact_url(
# Initialize CloudFront signer
signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
private_key=settings.cloudfront.private_key,
)

# Generate base URL based on environment
if settings.environment == "local":
base_url = get_artifact_url(
artifact_type=artifact.artifact_type,
artifact_id=artifact.id,
listing_id=listing_id,
name=s3_filename,
size=size,
)
else:
# For production, use CloudFront domain
base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}"
if size and artifact.artifact_type == "image":
base_url = f"{base_url}_{size}"
# Always use CloudFront domain and sign the URL
base_url = f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{listing_id}/{s3_filename}"
if size and artifact.artifact_type == "image":
base_url = f"{base_url}_{size}"

# Create and sign URL for production environment
if settings.environment != "local":
policy = signer.create_custom_policy(url=base_url, expire_days=180)
base_url = signer.generate_presigned_url(base_url, policy=policy)
# Create and sign URL
policy = signer.create_custom_policy(url=base_url, expire_days=180)
signed_url = signer.generate_presigned_url(base_url, policy=policy)

return RedirectResponse(url=base_url)
return RedirectResponse(url=signed_url)


class ArtifactUrls(BaseModel):
Expand All @@ -107,18 +95,24 @@ def get_artifact_url_response(artifact: Artifact) -> ArtifactUrls:

signer = CloudFrontUrlSigner(
key_id=settings.cloudfront.key_id,
private_key_path=settings.cloudfront.private_key_path,
private_key=settings.cloudfront.private_key,
)

expire_days = 180
expiration_time = int((datetime.utcnow() + timedelta(days=expire_days)).timestamp())

# Explicitly iterate over the literal types
sizes: list[Literal["small", "large"]] = ["small", "large"]
for size in sizes:
try:
url = artifact_urls[size]
cf_url = f"https://{settings.cloudfront.domain}/{url}"
cf_url = (
f"https://{settings.cloudfront.domain}/{artifact.artifact_type}/{artifact.listing_id}/{artifact.id}"
)
if size == "small":
cf_url += "_small_256x256"
elif size == "large":
cf_url += "_large_1536x1536"
cf_url += f"_{artifact.name}"

policy = signer.create_custom_policy(url=cf_url, expire_days=expire_days)
artifact_urls[size] = signer.generate_presigned_url(cf_url, policy=policy)
except KeyError:
Expand Down
23 changes: 11 additions & 12 deletions store/app/utils/cloudfront_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
class CloudFrontUrlSigner:
"""A class to generate signed URLs for AWS CloudFront using RSA keys."""

def __init__(self, key_id: str, private_key_path: str) -> None:
"""Initialize the CloudFrontUrlSigner with a key ID and the path to the private key file.
def __init__(self, key_id: str, private_key: str) -> None:
"""Initialize the CloudFrontUrlSigner with a key ID and private key content.

:param key_id: The CloudFront key ID associated with the public key in your CloudFront key group.
:param private_key_path: The path to the private key PEM file.
:param private_key: The private key content in PEM format.
"""
self.key_id = key_id
self.private_key_path = private_key_path
self.private_key = private_key
self.cf_signer = CloudFrontSigner(key_id, self._rsa_signer)

def _rsa_signer(self, message: bytes) -> bytes:
Expand All @@ -38,16 +38,15 @@ def _rsa_signer(self, message: bytes) -> bytes:
Raises:
ValueError: If the loaded key is not an RSA private key.
"""
with open(self.private_key_path, "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
)
private_key = serialization.load_pem_private_key(
self.private_key.encode("utf-8"),
password=None,
)

if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The provided key is not an RSA private key")

return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())
return private_key.sign(message, padding.PKCS1v15(), hashes.SHA1())

def generate_presigned_url(self, url: str, policy: Optional[str] = None) -> str:
"""Generate a presigned URL for CloudFront using an optional custom policy.
Expand Down
2 changes: 1 addition & 1 deletion store/settings/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class StripeSettings:
class CloudFrontSettings:
domain: str = field(default=II("oc.env:CLOUDFRONT_DOMAIN"))
key_id: str = field(default=II("oc.env:CLOUDFRONT_KEY_ID"))
private_key_path: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY_PATH"))
private_key: str = field(default=II("oc.env:CLOUDFRONT_PRIVATE_KEY"))


@dataclass
Expand Down
Loading