Skip to content

Commit

Permalink
feat(parser): add models for API GW Websockets events (#5597)
Browse files Browse the repository at this point in the history
* feature(parser): Parser models for API GW Websockets Events

* code review fixes

* fix typo in the doc. add optional model

* fix optional field

* change names to snake case

---------

Co-authored-by: Ran Isenberg <[email protected]>
Co-authored-by: Ana Falcão <[email protected]>
Co-authored-by: Ana Falcao <[email protected]>
Co-authored-by: Leandro Damascena <[email protected]>
  • Loading branch information
5 people authored Nov 24, 2024
1 parent d1a58cd commit 20c0b74
Show file tree
Hide file tree
Showing 10 changed files with 347 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aws_lambda_powertools/utilities/parser/envelopes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .apigw import ApiGatewayEnvelope
from .apigw_websocket import ApiGatewayWebSocketEnvelope
from .apigwv2 import ApiGatewayV2Envelope
from .base import BaseEnvelope
from .bedrock_agent import BedrockAgentEnvelope
Expand All @@ -17,6 +18,7 @@
__all__ = [
"ApiGatewayEnvelope",
"ApiGatewayV2Envelope",
"ApiGatewayWebSocketEnvelope",
"BedrockAgentEnvelope",
"CloudWatchLogsEnvelope",
"DynamoDBStreamEnvelope",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.utilities.parser.envelopes.base import BaseEnvelope
from aws_lambda_powertools.utilities.parser.models import APIGatewayWebSocketMessageEventModel

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.parser.types import Model

logger = logging.getLogger(__name__)


class ApiGatewayWebSocketEnvelope(BaseEnvelope):
"""API Gateway WebSockets envelope to extract data within body key of messages routes
(not disconnect or connect)"""

def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model | None:
"""Parses data found with model provided
Parameters
----------
data : dict
Lambda event to be parsed
model : type[Model]
Data model provided to parse after extracting data using envelope
Returns
-------
Any
Parsed detail payload with model provided
"""
logger.debug(
f"Parsing incoming data with Api Gateway WebSockets model {APIGatewayWebSocketMessageEventModel}",
)
parsed_envelope: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel.model_validate(
data,
)
logger.debug(f"Parsing event payload in `detail` with {model}")
return self._parse(data=parsed_envelope.body, model=model)
18 changes: 18 additions & 0 deletions aws_lambda_powertools/utilities/parser/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
APIGatewayEventRequestContext,
APIGatewayProxyEventModel,
)
from .apigw_websocket import (
APIGatewayWebSocketConnectEventModel,
APIGatewayWebSocketConnectEventRequestContext,
APIGatewayWebSocketDisconnectEventModel,
APIGatewayWebSocketDisconnectEventRequestContext,
APIGatewayWebSocketEventIdentity,
APIGatewayWebSocketEventRequestContextBase,
APIGatewayWebSocketMessageEventModel,
APIGatewayWebSocketMessageEventRequestContext,
)
from .apigwv2 import (
ApiGatewayAuthorizerRequestV2,
APIGatewayProxyEventV2Model,
Expand Down Expand Up @@ -105,6 +115,14 @@
__all__ = [
"APIGatewayProxyEventV2Model",
"ApiGatewayAuthorizerRequestV2",
"APIGatewayWebSocketEventIdentity",
"APIGatewayWebSocketMessageEventModel",
"APIGatewayWebSocketMessageEventRequestContext",
"APIGatewayWebSocketConnectEventModel",
"APIGatewayWebSocketConnectEventRequestContext",
"APIGatewayWebSocketDisconnectEventRequestContext",
"APIGatewayWebSocketDisconnectEventModel",
"APIGatewayWebSocketEventRequestContextBase",
"RequestContextV2",
"RequestContextV2Http",
"RequestContextV2Authorizer",
Expand Down
63 changes: 63 additions & 0 deletions aws_lambda_powertools/utilities/parser/models/apigw_websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from datetime import datetime
from typing import Dict, List, Literal, Optional, Type, Union

from pydantic import BaseModel, Field
from pydantic.networks import IPvAnyNetwork


class APIGatewayWebSocketEventIdentity(BaseModel):
source_ip: IPvAnyNetwork = Field(alias="sourceIp")
user_agent: Optional[str] = Field(None, alias="userAgent")

class APIGatewayWebSocketEventRequestContextBase(BaseModel):
extended_request_id: str = Field(alias="extendedRequestId")
request_time: str = Field(alias="requestTime")
stage: str = Field(alias="stage")
connected_at: datetime = Field(alias="connectedAt")
request_time_epoch: datetime = Field(alias="requestTimeEpoch")
identity: APIGatewayWebSocketEventIdentity = Field(alias="identity")
request_id: str = Field(alias="requestId")
domain_name: str = Field(alias="domainName")
connection_id: str = Field(alias="connectionId")
api_id: str = Field(alias="apiId")


class APIGatewayWebSocketMessageEventRequestContext(APIGatewayWebSocketEventRequestContextBase):
route_key: str = Field(alias="routeKey")
message_id: str = Field(alias="messageId")
event_type: Literal["MESSAGE"] = Field(alias="eventType")
message_direction: Literal["IN", "OUT"] = Field(alias="messageDirection")


class APIGatewayWebSocketConnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase):
route_key: Literal["$connect"] = Field(alias="routeKey")
event_type: Literal["CONNECT"] = Field(alias="eventType")
message_direction: Literal["IN"] = Field(alias="messageDirection")


class APIGatewayWebSocketDisconnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase):
route_key: Literal["$disconnect"] = Field(alias="routeKey")
disconnect_status_code: int = Field(alias="disconnectStatusCode")
event_type: Literal["DISCONNECT"] = Field(alias="eventType")
message_direction: Literal["IN"] = Field(alias="messageDirection")
disconnect_reason: str = Field(alias="disconnectReason")


class APIGatewayWebSocketConnectEventModel(BaseModel):
headers: Dict[str, str] = Field(alias="headers")
multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders")
request_context: APIGatewayWebSocketConnectEventRequestContext = Field(alias="requestContext")
is_base64_encoded: bool = Field(alias="isBase64Encoded")


class APIGatewayWebSocketDisconnectEventModel(BaseModel):
headers: Dict[str, str] = Field(alias="headers")
multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders")
request_context: APIGatewayWebSocketDisconnectEventRequestContext = Field(alias="requestContext")
is_base64_encoded: bool = Field(alias="isBase64Encoded")


class APIGatewayWebSocketMessageEventModel(BaseModel):
request_context: APIGatewayWebSocketMessageEventRequestContext = Field(alias="requestContext")
is_base64_encoded: bool = Field(alias="isBase64Encoded")
body: Optional[Union[str, Type[BaseModel]]] = Field(None, alias="body")
6 changes: 5 additions & 1 deletion docs/utilities/parser.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ The example above uses `SqsModel`. Other built-in models can be found below.
| **ApiGatewayAuthorizerRequest** | Lambda Event Source payload for Amazon API Gateway Lambda Authorizer with Request |
| **APIGatewayProxyEventV2Model** | Lambda Event Source payload for Amazon API Gateway v2 payload |
| **ApiGatewayAuthorizerRequestV2** | Lambda Event Source payload for Amazon API Gateway v2 Lambda Authorizer |
| **APIGatewayWebSocketMessageEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API message body |
| **APIGatewayWebSocketConnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $connect message |
| **APIGatewayWebSocketDisconnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $disconnect message |
| **BedrockAgentEventModel** | Lambda Event Source payload for Bedrock Agents |
| **CloudFormationCustomResourceCreateModel** | Lambda Event Source payload for AWS CloudFormation `CREATE` operation |
| **CloudFormationCustomResourceUpdateModel** | Lambda Event Source payload for AWS CloudFormation `UPDATE` operation |
Expand Down Expand Up @@ -188,8 +191,9 @@ You can use pre-built envelopes provided by the Parser to extract and parse spec
| **KinesisFirehoseEnvelope** | 1. Parses data using `KinesisFirehoseModel` which will base64 decode it. ``2. Parses records in in` Records` key using your model`` and returns them in a list. | `List[Model]` |
| **SnsEnvelope** | 1. Parses data using `SnsModel`. ``2. Parses records in `body` key using your model`` and return them in a list. | `List[Model]` |
| **SnsSqsEnvelope** | 1. Parses data using `SqsModel`. `` 2. Parses SNS records in `body` key using `SnsNotificationModel`. `` 3. Parses data in `Message` key using your model and return them in a list. | `List[Model]` |
| **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` |
| **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayProxyEventV2Model`. ``2. Parses `body` key using your model`` and returns it. | `Model` |
| **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` |
| **ApiGatewayWebSocketEnvelope** | 1. Parses data using `APIGatewayWebSocketMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` |
| **LambdaFunctionUrlEnvelope** | 1. Parses data using `LambdaFunctionUrlModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` |
| **KafkaEnvelope** | 1. Parses data using `KafkaRecordModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` |
| **VpcLatticeEnvelope** | 1. Parses data using `VpcLatticeModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` |
Expand Down
40 changes: 40 additions & 0 deletions tests/events/apiGatewayWebSocketApiConnect.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"headers": {
"Host": "fjnq7njcv2.execute-api.us-east-1.amazonaws.com",
"Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits",
"Sec-WebSocket-Key": "+W5xw47OHh3OTFsWKjGu9Q==",
"Sec-WebSocket-Version": "13",
"X-Amzn-Trace-Id": "Root=1-6731ebfc-08e1e656421db73c5d2eef31",
"X-Forwarded-For": "166.90.225.1",
"X-Forwarded-Port": "443",
"X-Forwarded-Proto": "https"
},
"multiValueHeaders": {
"Host": ["fjnq7njcv2.execute-api.us-east-1.amazonaws.com"],
"Sec-WebSocket-Extensions": ["permessage-deflate; client_max_window_bits"],
"Sec-WebSocket-Key": ["+W5xw47OHh3OTFsWKjGu9Q=="],
"Sec-WebSocket-Version": ["13"],
"X-Amzn-Trace-Id": ["Root=1-6731ebfc-08e1e656421db73c5d2eef31"],
"X-Forwarded-For": ["166.90.225.1"],
"X-Forwarded-Port": ["443"],
"X-Forwarded-Proto": ["https"]
},
"requestContext": {
"routeKey": "$connect",
"eventType": "CONNECT",
"extendedRequestId": "BFHPhFe3IAMF95g=",
"requestTime": "11/Nov/2024:11:35:24 +0000",
"messageDirection": "IN",
"stage": "prod",
"connectedAt": 1731324924553,
"requestTimeEpoch": 1731324924561,
"identity": {
"sourceIp": "166.90.225.1"
},
"requestId": "BFHPhFe3IAMF95g=",
"domainName": "asasasas.execute-api.us-east-1.amazonaws.com",
"connectionId": "BFHPhfCWIAMCKlQ=",
"apiId": "asasasas"
},
"isBase64Encoded": false
}
34 changes: 34 additions & 0 deletions tests/events/apiGatewayWebSocketApiDisconnect.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"headers": {
"Host": "asasasas.execute-api.us-east-1.amazonaws.com",
"x-api-key": "",
"X-Forwarded-For": "",
"x-restapi": ""
},
"multiValueHeaders": {
"Host": ["asasasas.execute-api.us-east-1.amazonaws.com"],
"x-api-key": [""],
"X-Forwarded-For": [""],
"x-restapi": [""]
},
"requestContext": {
"routeKey": "$disconnect",
"disconnectStatusCode": 1005,
"eventType": "DISCONNECT",
"extendedRequestId": "BFbOeE87IAMF31w=",
"requestTime": "11/Nov/2024:13:51:49 +0000",
"messageDirection": "IN",
"disconnectReason": "Client-side close frame status not set",
"stage": "prod",
"connectedAt": 1731332735513,
"requestTimeEpoch": 1731333109875,
"identity": {
"sourceIp": "166.90.225.1"
},
"requestId": "BFbOeE87IAMF31w=",
"domainName": "asasasas.execute-api.us-east-1.amazonaws.com",
"connectionId": "BFaT_fALIAMCKug=",
"apiId": "asasasas"
},
"isBase64Encoded": false
}
22 changes: 22 additions & 0 deletions tests/events/apiGatewayWebSocketApiMessage.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"requestContext": {
"routeKey": "chat",
"messageId": "BFaVtfGSIAMCKug=",
"eventType": "MESSAGE",
"extendedRequestId": "BFaVtH2HoAMFZEQ=",
"requestTime": "11/Nov/2024:13:45:46 +0000",
"messageDirection": "IN",
"stage": "prod",
"connectedAt": 1731332735513,
"requestTimeEpoch": 1731332746514,
"identity": {
"sourceIp": "166.90.225.1"
},
"requestId": "BFaVtH2HoAMFZEQ=",
"domainName": "asasasas.execute-api.us-east-1.amazonaws.com",
"connectionId": "BFaT_fALIAMCKug=",
"apiId": "asasasas"
},
"body": "{\"action\": \"chat\", \"message\": \"Hello from client\"}",
"isBase64Encoded": false
}
5 changes: 5 additions & 0 deletions tests/unit/parser/_pydantic/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ class MyApiGatewayBusiness(BaseModel):
username: str


class MyApiGatewayWebSocketBusiness(BaseModel):
message: str
action: str


class MyALambdaFuncUrlBusiness(BaseModel):
message: str
username: str
Expand Down
Loading

0 comments on commit 20c0b74

Please sign in to comment.