Skip to content

Commit

Permalink
#39 Handle SES payload invocations
Browse files Browse the repository at this point in the history
  • Loading branch information
danial-k committed Apr 26, 2024
1 parent 0f52e2d commit ab9d31e
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 31 deletions.
95 changes: 64 additions & 31 deletions lambda/app_handler/provider/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, payload):
self.payload = payload
self.content = payload
self.has_error = None
self.matched = False
self.parse(payload)

def parse(self, payload):
Expand All @@ -25,8 +26,13 @@ def parse(self, payload):
- Invocation as a Python function
- Invocation of Lambda function
- Invocation behind an AWS HTTP v2 API Gateway
- Invocation by an SNS topic
"""

# Check if lambda was invoked as an SNS message
if self.is_sns_message(payload):
return

if 'body' not in payload:
logging.debug('Body not in payload, returning payload')
return
Expand Down Expand Up @@ -62,39 +68,11 @@ def parse(self, payload):
content_type = content_type.lower()
logging.debug('Content-Type detected: %s', content_type)

matched = False

if content_type.startswith('application/x-www-form-urlencoded'):
logging.debug('Decoding URL encoded form')
try:
html_decoded = urllib.parse.unquote(self.content)
# Assume each value should only appear once, drop repeat keys
self.content = dict(urllib.parse.parse_qsl(html_decoded))
logging.debug(self.content)
matched = True
except(
AttributeError,
ValueError
) as exception:
logging.critical('Error decoding URL encoded form')
logging.critical(exception)
self.has_error = True
self.is_form_url_encoded(content_type=content_type)

if content_type.startswith('application/json'):
logging.debug('Loading JSON string')
try:
self.content = json.loads(self.content, strict=False)
logging.debug(self.content)
matched = True
except (
JSONDecodeError,
TypeError
) as exception:
logging.critical('Error loading string as JSON')
logging.critical(exception)
self.has_error = True
self.is_application_json(content_type=content_type)

if not matched:
if not self.matched:
logging.critical('Error determining how to load content type.')
self.has_error = True

Expand All @@ -119,3 +97,58 @@ def get_remote_ip(self):
logging.debug('Fetched remote IP from payload: %s', remote_ip)

return remote_ip


def is_sns_message(self, payload):
"""
Determine if payload is an SNS message
"""
# Handle first SNS message
if 'Records' in payload and isinstance(payload['Records'], list):
record = payload['Records'][0]
if 'Sns' in record and 'Message' in record['Sns'] and 'Subject' in record['Sns']:
logging.debug('SNS payload extracted')
self.content = record['Sns']
return True

return False


def is_application_json(self, content_type: str):
"""
Determine of payload is JSON
"""
if content_type.startswith('application/json'):
logging.debug('Loading JSON string')
try:
self.content = json.loads(self.content, strict=False)
logging.debug(self.content)
self.matched = True
except (
JSONDecodeError,
TypeError
) as exception:
logging.critical('Error loading string as JSON')
logging.critical(exception)
self.has_error = True


def is_form_url_encoded(self, content_type: str):
"""
Determine if payload is form URL encoded
"""
if content_type.startswith('application/x-www-form-urlencoded'):
logging.debug('Decoding URL encoded form')
try:
html_decoded = urllib.parse.unquote(self.content)
# Assume each value should only appear once, drop repeat keys
self.content = dict(urllib.parse.parse_qsl(html_decoded))
logging.debug(self.content)
self.matched = True
except(
AttributeError,
ValueError
) as exception:
logging.critical('Error decoding URL encoded form')
logging.critical(exception)
self.has_error = True
31 changes: 31 additions & 0 deletions lambda/tests/unit/fixtures/sns_message_v1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:us-east-1:123456789012:sns-lambda:21be56ed-a058-49f5-8c98-aedd2564c486",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "2019-01-02T12:45:07.000Z",
"Signature": "tcc6faL2yUC6dgZdmrwh1Y4cGa/ebXEkAi6RibDsvpi+tE/1+82j...65r==",
"SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-ac565b8b1a6c5d002d285f9598aa1d9b.pem",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": "Hello from SNS!",
"MessageAttributes": {
"Test": {
"Type": "String",
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
},
"Type": "Notification",
"UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:test-lambda:21be56ed-a058-49f5-8c98-aedd2564c486",
"TopicArn":"arn:aws:sns:us-east-1:123456789012:sns-lambda",
"Subject": "TestInvoke"
}
}
]
}
11 changes: 11 additions & 0 deletions lambda/tests/unit/provider/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,14 @@ def test_get_remote_ip():
eventv2 = get_json_fixture_file('httpapiv2_gateway_request_urlencoded_base64.json')
assert RequestProvider(eventv1).get_remote_ip() == '127.0.0.1'
assert RequestProvider(eventv2).get_remote_ip() == '127.0.0.1'

# SNS topic

def test_payload_parse_sns_message():
"""
Ensure that an SNS message is correctly parsed
"""

sns_request = get_json_fixture_file('sns_message_v1.json')
assert RequestProvider(sns_request).content['Message'] == 'Hello from SNS!'
assert RequestProvider(sns_request).content['Subject'] == 'TestInvoke'

0 comments on commit ab9d31e

Please sign in to comment.