-
Notifications
You must be signed in to change notification settings - Fork 401
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'develop' into dependabot/pip/develop/pytest-7.4.2
- Loading branch information
Showing
44 changed files
with
3,743 additions
and
50 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware | ||
|
||
__all__ = ["BaseMiddlewareHandler", "NextMiddleware"] |
122 changes: 122 additions & 0 deletions
122
aws_lambda_powertools/event_handler/middlewares/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Generic | ||
|
||
from typing_extensions import Protocol | ||
|
||
from aws_lambda_powertools.event_handler.api_gateway import Response | ||
from aws_lambda_powertools.event_handler.types import EventHandlerInstance | ||
|
||
|
||
class NextMiddleware(Protocol): | ||
def __call__(self, app: EventHandlerInstance) -> Response: | ||
"""Protocol for callback regardless of next_middleware(app), get_response(app) etc""" | ||
... | ||
|
||
def __name__(self) -> str: # noqa A003 | ||
"""Protocol for name of the Middleware""" | ||
... | ||
|
||
|
||
class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC): | ||
"""Base implementation for Middlewares to run code before and after in a chain. | ||
This is the middleware handler function where middleware logic is implemented. | ||
The next middleware handler is represented by `next_middleware`, returning a Response object. | ||
Examples | ||
-------- | ||
**Correlation ID Middleware** | ||
```python | ||
import requests | ||
from aws_lambda_powertools import Logger | ||
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response | ||
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware | ||
app = APIGatewayRestResolver() | ||
logger = Logger() | ||
class CorrelationIdMiddleware(BaseMiddlewareHandler): | ||
def __init__(self, header: str): | ||
super().__init__() | ||
self.header = header | ||
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: | ||
# BEFORE logic | ||
request_id = app.current_event.request_context.request_id | ||
correlation_id = app.current_event.get_header_value( | ||
name=self.header, | ||
default_value=request_id, | ||
) | ||
# Call next middleware or route handler ('/todos') | ||
response = next_middleware(app) | ||
# AFTER logic | ||
response.headers[self.header] = correlation_id | ||
return response | ||
@app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")]) | ||
def get_todos(): | ||
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") | ||
todos.raise_for_status() | ||
# for brevity, we'll limit to the first 10 only | ||
return {"todos": todos.json()[:10]} | ||
@logger.inject_lambda_context | ||
def lambda_handler(event, context): | ||
return app.resolve(event, context) | ||
``` | ||
""" | ||
|
||
@abstractmethod | ||
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: | ||
""" | ||
The Middleware Handler | ||
Parameters | ||
---------- | ||
app: EventHandlerInstance | ||
An instance of an Event Handler that implements ApiGatewayResolver | ||
next_middleware: NextMiddleware | ||
The next middleware handler in the chain | ||
Returns | ||
------- | ||
Response | ||
The response from the next middleware handler in the chain | ||
""" | ||
raise NotImplementedError() | ||
|
||
@property | ||
def __name__(self) -> str: # noqa A003 | ||
return str(self.__class__.__name__) | ||
|
||
def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: | ||
""" | ||
The Middleware handler function. | ||
Parameters | ||
---------- | ||
app: ApiGatewayResolver | ||
An instance of an Event Handler that implements ApiGatewayResolver | ||
next_middleware: NextMiddleware | ||
The next middleware handler in the chain | ||
Returns | ||
------- | ||
Response | ||
The response from the next middleware handler in the chain | ||
""" | ||
return self.handler(app, next_middleware) |
124 changes: 124 additions & 0 deletions
124
aws_lambda_powertools/event_handler/middlewares/schema_validation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import logging | ||
from typing import Dict, Optional | ||
|
||
from aws_lambda_powertools.event_handler.api_gateway import Response | ||
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError | ||
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware | ||
from aws_lambda_powertools.event_handler.types import EventHandlerInstance | ||
from aws_lambda_powertools.utilities.validation import validate | ||
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SchemaValidationMiddleware(BaseMiddlewareHandler): | ||
"""Middleware to validate API request and response against JSON Schema using the [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/). | ||
Examples | ||
-------- | ||
**Validating incoming event** | ||
```python | ||
import requests | ||
from aws_lambda_powertools import Logger | ||
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response | ||
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware | ||
from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware | ||
app = APIGatewayRestResolver() | ||
logger = Logger() | ||
json_schema_validation = SchemaValidationMiddleware(inbound_schema=INCOMING_JSON_SCHEMA) | ||
@app.get("/todos", middlewares=[json_schema_validation]) | ||
def get_todos(): | ||
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") | ||
todos.raise_for_status() | ||
# for brevity, we'll limit to the first 10 only | ||
return {"todos": todos.json()[:10]} | ||
@logger.inject_lambda_context | ||
def lambda_handler(event, context): | ||
return app.resolve(event, context) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
inbound_schema: Dict, | ||
inbound_formats: Optional[Dict] = None, | ||
outbound_schema: Optional[Dict] = None, | ||
outbound_formats: Optional[Dict] = None, | ||
): | ||
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters. | ||
Parameters | ||
---------- | ||
inbound_schema : Dict | ||
JSON Schema to validate incoming event | ||
inbound_formats : Optional[Dict], optional | ||
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None | ||
JSON Schema to validate outbound event, by default None | ||
outbound_formats : Optional[Dict], optional | ||
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None | ||
""" # noqa: E501 | ||
super().__init__() | ||
self.inbound_schema = inbound_schema | ||
self.inbound_formats = inbound_formats | ||
self.outbound_schema = outbound_schema | ||
self.outbound_formats = outbound_formats | ||
|
||
def bad_response(self, error: SchemaValidationError) -> Response: | ||
message: str = f"Bad Response: {error.message}" | ||
logger.debug(message) | ||
raise BadRequestError(message) | ||
|
||
def bad_request(self, error: SchemaValidationError) -> Response: | ||
message: str = f"Bad Request: {error.message}" | ||
logger.debug(message) | ||
raise BadRequestError(message) | ||
|
||
def bad_config(self, error: InvalidSchemaFormatError) -> Response: | ||
logger.debug(f"Invalid Schema Format: {error}") | ||
raise InternalServerError("Internal Server Error") | ||
|
||
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: | ||
"""Validates incoming JSON payload (body) against JSON Schema provided. | ||
Parameters | ||
---------- | ||
app : EventHandlerInstance | ||
An instance of an Event Handler | ||
next_middleware : NextMiddleware | ||
Callable to get response from the next middleware or route handler in the chain | ||
Returns | ||
------- | ||
Response | ||
It can return three types of response objects | ||
- Original response: Propagates HTTP response returned from the next middleware if validation succeeds | ||
- HTTP 400: Payload or response failed JSON Schema validation | ||
- HTTP 500: JSON Schema provided has incorrect format | ||
""" | ||
try: | ||
validate(event=app.current_event.json_body, schema=self.inbound_schema, formats=self.inbound_formats) | ||
except SchemaValidationError as error: | ||
return self.bad_request(error) | ||
except InvalidSchemaFormatError as error: | ||
return self.bad_config(error) | ||
|
||
result = next_middleware(app) | ||
|
||
if self.outbound_formats is not None: | ||
try: | ||
validate(event=result.body, schema=self.inbound_schema, formats=self.inbound_formats) | ||
except SchemaValidationError as error: | ||
return self.bad_response(error) | ||
except InvalidSchemaFormatError as error: | ||
return self.bad_config(error) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from typing import TypeVar | ||
|
||
from aws_lambda_powertools.event_handler import ApiGatewayResolver | ||
|
||
EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,14 @@ | ||
import sys | ||
from typing import Any, Callable, Dict, List, TypeVar, Union | ||
|
||
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 | ||
# JSON primitives only, mypy doesn't support recursive tho | ||
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] | ||
|
||
|
||
if sys.version_info >= (3, 8): | ||
from typing import Protocol | ||
else: | ||
from typing_extensions import Protocol | ||
|
||
__all__ = ["Protocol"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.