diff --git a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py index d5754481ee8..e1ac8cdbf5e 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py @@ -1,4 +1,5 @@ from .apigw import ApiGatewayEnvelope +from .apigw_websocket import ApiGatewayWebSocketEnvelope from .apigwv2 import ApiGatewayV2Envelope from .base import BaseEnvelope from .bedrock_agent import BedrockAgentEnvelope @@ -17,6 +18,7 @@ __all__ = [ "ApiGatewayEnvelope", "ApiGatewayV2Envelope", + "ApiGatewayWebSocketEnvelope", "BedrockAgentEnvelope", "CloudWatchLogsEnvelope", "DynamoDBStreamEnvelope", diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py new file mode 100644 index 00000000000..37d08dec180 --- /dev/null +++ b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from aws_lambda_powertools.utilities.parser.envelopes.base import BaseEnvelope +from aws_lambda_powertools.utilities.parser.models import APIGatewayWebSocketMessageEventModel + +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.parser.types import Model + +logger = logging.getLogger(__name__) + + +class ApiGatewayWebSocketEnvelope(BaseEnvelope): + """API Gateway WebSockets envelope to extract data within body key of messages routes + (not disconnect or connect)""" + + def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model | None: + """Parses data found with model provided + + Parameters + ---------- + data : dict + Lambda event to be parsed + model : type[Model] + Data model provided to parse after extracting data using envelope + + Returns + ------- + Any + Parsed detail payload with model provided + """ + logger.debug( + f"Parsing incoming data with Api Gateway WebSockets model {APIGatewayWebSocketMessageEventModel}", + ) + parsed_envelope: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel.model_validate( + data, + ) + logger.debug(f"Parsing event payload in `detail` with {model}") + return self._parse(data=parsed_envelope.body, model=model) diff --git a/aws_lambda_powertools/utilities/parser/models/__init__.py b/aws_lambda_powertools/utilities/parser/models/__init__.py index ea166cd0a0a..7c409ef6b83 100644 --- a/aws_lambda_powertools/utilities/parser/models/__init__.py +++ b/aws_lambda_powertools/utilities/parser/models/__init__.py @@ -7,6 +7,16 @@ APIGatewayEventRequestContext, APIGatewayProxyEventModel, ) +from .apigw_websocket import ( + APIGatewayWebSocketConnectEventModel, + APIGatewayWebSocketConnectEventRequestContext, + APIGatewayWebSocketDisconnectEventModel, + APIGatewayWebSocketDisconnectEventRequestContext, + APIGatewayWebSocketEventIdentity, + APIGatewayWebSocketEventRequestContextBase, + APIGatewayWebSocketMessageEventModel, + APIGatewayWebSocketMessageEventRequestContext, +) from .apigwv2 import ( ApiGatewayAuthorizerRequestV2, APIGatewayProxyEventV2Model, @@ -105,6 +115,14 @@ __all__ = [ "APIGatewayProxyEventV2Model", "ApiGatewayAuthorizerRequestV2", + "APIGatewayWebSocketEventIdentity", + "APIGatewayWebSocketMessageEventModel", + "APIGatewayWebSocketMessageEventRequestContext", + "APIGatewayWebSocketConnectEventModel", + "APIGatewayWebSocketConnectEventRequestContext", + "APIGatewayWebSocketDisconnectEventRequestContext", + "APIGatewayWebSocketDisconnectEventModel", + "APIGatewayWebSocketEventRequestContextBase", "RequestContextV2", "RequestContextV2Http", "RequestContextV2Authorizer", diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py new file mode 100644 index 00000000000..0655825e776 --- /dev/null +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -0,0 +1,63 @@ +from datetime import datetime +from typing import Dict, List, Literal, Optional, Type, Union + +from pydantic import BaseModel, Field +from pydantic.networks import IPvAnyNetwork + + +class APIGatewayWebSocketEventIdentity(BaseModel): + source_ip: IPvAnyNetwork = Field(alias="sourceIp") + user_agent: Optional[str] = Field(None, alias="userAgent") + +class APIGatewayWebSocketEventRequestContextBase(BaseModel): + extended_request_id: str = Field(alias="extendedRequestId") + request_time: str = Field(alias="requestTime") + stage: str = Field(alias="stage") + connected_at: datetime = Field(alias="connectedAt") + request_time_epoch: datetime = Field(alias="requestTimeEpoch") + identity: APIGatewayWebSocketEventIdentity = Field(alias="identity") + request_id: str = Field(alias="requestId") + domain_name: str = Field(alias="domainName") + connection_id: str = Field(alias="connectionId") + api_id: str = Field(alias="apiId") + + +class APIGatewayWebSocketMessageEventRequestContext(APIGatewayWebSocketEventRequestContextBase): + route_key: str = Field(alias="routeKey") + message_id: str = Field(alias="messageId") + event_type: Literal["MESSAGE"] = Field(alias="eventType") + message_direction: Literal["IN", "OUT"] = Field(alias="messageDirection") + + +class APIGatewayWebSocketConnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): + route_key: Literal["$connect"] = Field(alias="routeKey") + event_type: Literal["CONNECT"] = Field(alias="eventType") + message_direction: Literal["IN"] = Field(alias="messageDirection") + + +class APIGatewayWebSocketDisconnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): + route_key: Literal["$disconnect"] = Field(alias="routeKey") + disconnect_status_code: int = Field(alias="disconnectStatusCode") + event_type: Literal["DISCONNECT"] = Field(alias="eventType") + message_direction: Literal["IN"] = Field(alias="messageDirection") + disconnect_reason: str = Field(alias="disconnectReason") + + +class APIGatewayWebSocketConnectEventModel(BaseModel): + headers: Dict[str, str] = Field(alias="headers") + multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders") + request_context: APIGatewayWebSocketConnectEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") + + +class APIGatewayWebSocketDisconnectEventModel(BaseModel): + headers: Dict[str, str] = Field(alias="headers") + multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders") + request_context: APIGatewayWebSocketDisconnectEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") + + +class APIGatewayWebSocketMessageEventModel(BaseModel): + request_context: APIGatewayWebSocketMessageEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") + body: Optional[Union[str, Type[BaseModel]]] = Field(None, alias="body") diff --git a/docs/utilities/parser.md b/docs/utilities/parser.md index 4c86c983d31..4cf11a32769 100644 --- a/docs/utilities/parser.md +++ b/docs/utilities/parser.md @@ -108,6 +108,9 @@ The example above uses `SqsModel`. Other built-in models can be found below. | **ApiGatewayAuthorizerRequest** | Lambda Event Source payload for Amazon API Gateway Lambda Authorizer with Request | | **APIGatewayProxyEventV2Model** | Lambda Event Source payload for Amazon API Gateway v2 payload | | **ApiGatewayAuthorizerRequestV2** | Lambda Event Source payload for Amazon API Gateway v2 Lambda Authorizer | +| **APIGatewayWebSocketMessageEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API message body | +| **APIGatewayWebSocketConnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $connect message | +| **APIGatewayWebSocketDisconnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $disconnect message | | **BedrockAgentEventModel** | Lambda Event Source payload for Bedrock Agents | | **CloudFormationCustomResourceCreateModel** | Lambda Event Source payload for AWS CloudFormation `CREATE` operation | | **CloudFormationCustomResourceUpdateModel** | Lambda Event Source payload for AWS CloudFormation `UPDATE` operation | @@ -188,8 +191,9 @@ You can use pre-built envelopes provided by the Parser to extract and parse spec | **KinesisFirehoseEnvelope** | 1. Parses data using `KinesisFirehoseModel` which will base64 decode it. ``2. Parses records in in` Records` key using your model`` and returns them in a list. | `List[Model]` | | **SnsEnvelope** | 1. Parses data using `SnsModel`. ``2. Parses records in `body` key using your model`` and return them in a list. | `List[Model]` | | **SnsSqsEnvelope** | 1. Parses data using `SqsModel`. `` 2. Parses SNS records in `body` key using `SnsNotificationModel`. `` 3. Parses data in `Message` key using your model and return them in a list. | `List[Model]` | -| **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayProxyEventV2Model`. ``2. Parses `body` key using your model`` and returns it. | `Model` | +| **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | +| **ApiGatewayWebSocketEnvelope** | 1. Parses data using `APIGatewayWebSocketMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **LambdaFunctionUrlEnvelope** | 1. Parses data using `LambdaFunctionUrlModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **KafkaEnvelope** | 1. Parses data using `KafkaRecordModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` | | **VpcLatticeEnvelope** | 1. Parses data using `VpcLatticeModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` | diff --git a/tests/events/apiGatewayWebSocketApiConnect.json b/tests/events/apiGatewayWebSocketApiConnect.json new file mode 100644 index 00000000000..27f8794c9bd --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiConnect.json @@ -0,0 +1,40 @@ +{ + "headers": { + "Host": "fjnq7njcv2.execute-api.us-east-1.amazonaws.com", + "Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits", + "Sec-WebSocket-Key": "+W5xw47OHh3OTFsWKjGu9Q==", + "Sec-WebSocket-Version": "13", + "X-Amzn-Trace-Id": "Root=1-6731ebfc-08e1e656421db73c5d2eef31", + "X-Forwarded-For": "166.90.225.1", + "X-Forwarded-Port": "443", + "X-Forwarded-Proto": "https" + }, + "multiValueHeaders": { + "Host": ["fjnq7njcv2.execute-api.us-east-1.amazonaws.com"], + "Sec-WebSocket-Extensions": ["permessage-deflate; client_max_window_bits"], + "Sec-WebSocket-Key": ["+W5xw47OHh3OTFsWKjGu9Q=="], + "Sec-WebSocket-Version": ["13"], + "X-Amzn-Trace-Id": ["Root=1-6731ebfc-08e1e656421db73c5d2eef31"], + "X-Forwarded-For": ["166.90.225.1"], + "X-Forwarded-Port": ["443"], + "X-Forwarded-Proto": ["https"] + }, + "requestContext": { + "routeKey": "$connect", + "eventType": "CONNECT", + "extendedRequestId": "BFHPhFe3IAMF95g=", + "requestTime": "11/Nov/2024:11:35:24 +0000", + "messageDirection": "IN", + "stage": "prod", + "connectedAt": 1731324924553, + "requestTimeEpoch": 1731324924561, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFHPhFe3IAMF95g=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFHPhfCWIAMCKlQ=", + "apiId": "asasasas" + }, + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/events/apiGatewayWebSocketApiDisconnect.json b/tests/events/apiGatewayWebSocketApiDisconnect.json new file mode 100644 index 00000000000..f4624562ef6 --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiDisconnect.json @@ -0,0 +1,34 @@ +{ + "headers": { + "Host": "asasasas.execute-api.us-east-1.amazonaws.com", + "x-api-key": "", + "X-Forwarded-For": "", + "x-restapi": "" + }, + "multiValueHeaders": { + "Host": ["asasasas.execute-api.us-east-1.amazonaws.com"], + "x-api-key": [""], + "X-Forwarded-For": [""], + "x-restapi": [""] + }, + "requestContext": { + "routeKey": "$disconnect", + "disconnectStatusCode": 1005, + "eventType": "DISCONNECT", + "extendedRequestId": "BFbOeE87IAMF31w=", + "requestTime": "11/Nov/2024:13:51:49 +0000", + "messageDirection": "IN", + "disconnectReason": "Client-side close frame status not set", + "stage": "prod", + "connectedAt": 1731332735513, + "requestTimeEpoch": 1731333109875, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFbOeE87IAMF31w=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFaT_fALIAMCKug=", + "apiId": "asasasas" + }, + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/events/apiGatewayWebSocketApiMessage.json b/tests/events/apiGatewayWebSocketApiMessage.json new file mode 100644 index 00000000000..908a713ce20 --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiMessage.json @@ -0,0 +1,22 @@ +{ + "requestContext": { + "routeKey": "chat", + "messageId": "BFaVtfGSIAMCKug=", + "eventType": "MESSAGE", + "extendedRequestId": "BFaVtH2HoAMFZEQ=", + "requestTime": "11/Nov/2024:13:45:46 +0000", + "messageDirection": "IN", + "stage": "prod", + "connectedAt": 1731332735513, + "requestTimeEpoch": 1731332746514, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFaVtH2HoAMFZEQ=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFaT_fALIAMCKug=", + "apiId": "asasasas" + }, + "body": "{\"action\": \"chat\", \"message\": \"Hello from client\"}", + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/unit/parser/_pydantic/schemas.py b/tests/unit/parser/_pydantic/schemas.py index b4b69135ff9..0713924c486 100644 --- a/tests/unit/parser/_pydantic/schemas.py +++ b/tests/unit/parser/_pydantic/schemas.py @@ -87,6 +87,11 @@ class MyApiGatewayBusiness(BaseModel): username: str +class MyApiGatewayWebSocketBusiness(BaseModel): + message: str + action: str + + class MyALambdaFuncUrlBusiness(BaseModel): message: str username: str diff --git a/tests/unit/parser/_pydantic/test_apigw_websockets.py b/tests/unit/parser/_pydantic/test_apigw_websockets.py new file mode 100644 index 00000000000..aea77217d93 --- /dev/null +++ b/tests/unit/parser/_pydantic/test_apigw_websockets.py @@ -0,0 +1,117 @@ +from aws_lambda_powertools.utilities.parser import envelopes, parse +from aws_lambda_powertools.utilities.parser.models import ( + APIGatewayWebSocketConnectEventModel, + APIGatewayWebSocketDisconnectEventModel, + APIGatewayWebSocketMessageEventModel, +) +from tests.functional.utils import load_event +from tests.unit.parser._pydantic.schemas import MyApiGatewayWebSocketBusiness + + +def test_apigw_websocket_message_event_with_envelope(): + raw_event = load_event("apiGatewayWebSocketApiMessage.json") + raw_event["body"] = '{"action": "chat", "message": "Hello Ran"}' + parsed_event: MyApiGatewayWebSocketBusiness = parse( + event=raw_event, + model=MyApiGatewayWebSocketBusiness, + envelope=envelopes.ApiGatewayWebSocketEnvelope, + ) + + assert parsed_event.message == "Hello Ran" + assert parsed_event.action == "chat" + + +def test_apigw_websocket_message_event(): + raw_event = load_event("apiGatewayWebSocketApiMessage.json") + parsed_event: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel(**raw_event) + + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) + assert convert_time == 1731332746514 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) + assert convert_time == 1731332735513 + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.message_id == raw_event["requestContext"]["messageId"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] + + assert parsed_event.body == raw_event["body"] + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] + + +# not sure you can send an empty body TBH but it was a test in api gw so i kept it here, needs verification +def test_apigw_websocket_message_event_empty_body(): + event = load_event("apiGatewayWebSocketApiMessage.json") + event["body"] = None + parse(event=event, model=APIGatewayWebSocketMessageEventModel) + + +def test_apigw_websocket_connect_event(): + raw_event = load_event("apiGatewayWebSocketApiConnect.json") + parsed_event: APIGatewayWebSocketConnectEventModel = APIGatewayWebSocketConnectEventModel(**raw_event) + + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) + assert convert_time == 1731324924561 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) + assert convert_time == 1731324924553 + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] + + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] + + +def test_apigw_websocket_disconnect_event(): + raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") + parsed_event: APIGatewayWebSocketDisconnectEventModel = APIGatewayWebSocketDisconnectEventModel(**raw_event) + + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) + assert convert_time == 1731333109875 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) + assert convert_time == 1731332735513 + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] + assert request_context.disconnect_reason == raw_event["requestContext"]["disconnectReason"] + assert request_context.disconnect_status_code == raw_event["requestContext"]["disconnectStatusCode"] + + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] \ No newline at end of file