Skip to content

Commit

Permalink
Merge branch 'develop' into dependabot/pip/develop/pytest-7.4.2
Browse files Browse the repository at this point in the history
  • Loading branch information
leandrodamascena authored Sep 7, 2023
2 parents 98a0fc7 + 9625d37 commit 7990662
Show file tree
Hide file tree
Showing 44 changed files with 3,743 additions and 50 deletions.
410 changes: 381 additions & 29 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions aws_lambda_powertools/event_handler/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware

__all__ = ["BaseMiddlewareHandler", "NextMiddleware"]
122 changes: 122 additions & 0 deletions aws_lambda_powertools/event_handler/middlewares/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from abc import ABC, abstractmethod
from typing import Generic

from typing_extensions import Protocol

from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.types import EventHandlerInstance


class NextMiddleware(Protocol):
def __call__(self, app: EventHandlerInstance) -> Response:
"""Protocol for callback regardless of next_middleware(app), get_response(app) etc"""
...

def __name__(self) -> str: # noqa A003
"""Protocol for name of the Middleware"""
...


class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC):
"""Base implementation for Middlewares to run code before and after in a chain.
This is the middleware handler function where middleware logic is implemented.
The next middleware handler is represented by `next_middleware`, returning a Response object.
Examples
--------
**Correlation ID Middleware**
```python
import requests
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
app = APIGatewayRestResolver()
logger = Logger()
class CorrelationIdMiddleware(BaseMiddlewareHandler):
def __init__(self, header: str):
super().__init__()
self.header = header
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
# BEFORE logic
request_id = app.current_event.request_context.request_id
correlation_id = app.current_event.get_header_value(
name=self.header,
default_value=request_id,
)
# Call next middleware or route handler ('/todos')
response = next_middleware(app)
# AFTER logic
response.headers[self.header] = correlation_id
return response
@app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")])
def get_todos():
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
todos.raise_for_status()
# for brevity, we'll limit to the first 10 only
return {"todos": todos.json()[:10]}
@logger.inject_lambda_context
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

@abstractmethod
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
"""
The Middleware Handler
Parameters
----------
app: EventHandlerInstance
An instance of an Event Handler that implements ApiGatewayResolver
next_middleware: NextMiddleware
The next middleware handler in the chain
Returns
-------
Response
The response from the next middleware handler in the chain
"""
raise NotImplementedError()

@property
def __name__(self) -> str: # noqa A003
return str(self.__class__.__name__)

def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
"""
The Middleware handler function.
Parameters
----------
app: ApiGatewayResolver
An instance of an Event Handler that implements ApiGatewayResolver
next_middleware: NextMiddleware
The next middleware handler in the chain
Returns
-------
Response
The response from the next middleware handler in the chain
"""
return self.handler(app, next_middleware)
124 changes: 124 additions & 0 deletions aws_lambda_powertools/event_handler/middlewares/schema_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import logging
from typing import Dict, Optional

from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
from aws_lambda_powertools.utilities.validation import validate
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError

logger = logging.getLogger(__name__)


class SchemaValidationMiddleware(BaseMiddlewareHandler):
"""Middleware to validate API request and response against JSON Schema using the [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/).
Examples
--------
**Validating incoming event**
```python
import requests
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware
app = APIGatewayRestResolver()
logger = Logger()
json_schema_validation = SchemaValidationMiddleware(inbound_schema=INCOMING_JSON_SCHEMA)
@app.get("/todos", middlewares=[json_schema_validation])
def get_todos():
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
todos.raise_for_status()
# for brevity, we'll limit to the first 10 only
return {"todos": todos.json()[:10]}
@logger.inject_lambda_context
def lambda_handler(event, context):
return app.resolve(event, context)
```
"""

def __init__(
self,
inbound_schema: Dict,
inbound_formats: Optional[Dict] = None,
outbound_schema: Optional[Dict] = None,
outbound_formats: Optional[Dict] = None,
):
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
Parameters
----------
inbound_schema : Dict
JSON Schema to validate incoming event
inbound_formats : Optional[Dict], optional
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
JSON Schema to validate outbound event, by default None
outbound_formats : Optional[Dict], optional
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
""" # noqa: E501
super().__init__()
self.inbound_schema = inbound_schema
self.inbound_formats = inbound_formats
self.outbound_schema = outbound_schema
self.outbound_formats = outbound_formats

def bad_response(self, error: SchemaValidationError) -> Response:
message: str = f"Bad Response: {error.message}"
logger.debug(message)
raise BadRequestError(message)

def bad_request(self, error: SchemaValidationError) -> Response:
message: str = f"Bad Request: {error.message}"
logger.debug(message)
raise BadRequestError(message)

def bad_config(self, error: InvalidSchemaFormatError) -> Response:
logger.debug(f"Invalid Schema Format: {error}")
raise InternalServerError("Internal Server Error")

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
"""Validates incoming JSON payload (body) against JSON Schema provided.
Parameters
----------
app : EventHandlerInstance
An instance of an Event Handler
next_middleware : NextMiddleware
Callable to get response from the next middleware or route handler in the chain
Returns
-------
Response
It can return three types of response objects
- Original response: Propagates HTTP response returned from the next middleware if validation succeeds
- HTTP 400: Payload or response failed JSON Schema validation
- HTTP 500: JSON Schema provided has incorrect format
"""
try:
validate(event=app.current_event.json_body, schema=self.inbound_schema, formats=self.inbound_formats)
except SchemaValidationError as error:
return self.bad_request(error)
except InvalidSchemaFormatError as error:
return self.bad_config(error)

result = next_middleware(app)

if self.outbound_formats is not None:
try:
validate(event=result.body, schema=self.inbound_schema, formats=self.inbound_formats)
except SchemaValidationError as error:
return self.bad_response(error)
except InvalidSchemaFormatError as error:
return self.bad_config(error)

return result
5 changes: 5 additions & 0 deletions aws_lambda_powertools/event_handler/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from typing import TypeVar

from aws_lambda_powertools.event_handler import ApiGatewayResolver

EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver)
9 changes: 9 additions & 0 deletions aws_lambda_powertools/shared/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import sys
from typing import Any, Callable, Dict, List, TypeVar, Union

AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
# JSON primitives only, mypy doesn't support recursive tho
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol

__all__ = ["Protocol"]
21 changes: 19 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
from collections.abc import Mapping
from typing import Any, Callable, Dict, Iterator, List, Optional
from typing import Any, Callable, Dict, Iterator, List, Optional, overload

from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
Expand Down Expand Up @@ -156,7 +156,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
default_value=default_value,
)

# Maintenance: missing @overload to ensure return type is a str when default_value is set
@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...

def get_header_value(
self,
name: str,
Expand Down
20 changes: 19 additions & 1 deletion aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, overload

from aws_lambda_powertools.shared.headers_serializer import (
BaseHeadersSerializer,
Expand Down Expand Up @@ -91,6 +91,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
default_value=default_value,
)

@overload
def get_header_value(
self,
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...

def get_header_value(
self,
name: str,
Expand Down
Loading

0 comments on commit 7990662

Please sign in to comment.