From 9625d373570200064fbca851e48ed2981078a6f2 Mon Sep 17 00:00:00 2001 From: walmsles <2704782+walmsles@users.noreply.github.com> Date: Fri, 8 Sep 2023 06:22:26 +1000 Subject: [PATCH] feat(event_handler): add Middleware support for REST Event Handler (#2917) Co-authored-by: Heitor Lessa Co-authored-by: Heitor Lessa Co-authored-by: Leandro Damascena --- .../event_handler/api_gateway.py | 410 ++++- .../event_handler/middlewares/__init__.py | 3 + .../event_handler/middlewares/base.py | 122 ++ .../middlewares/schema_validation.py | 124 ++ aws_lambda_powertools/event_handler/types.py | 5 + aws_lambda_powertools/shared/types.py | 9 + .../utilities/data_classes/common.py | 21 +- .../utilities/data_classes/vpc_lattice.py | 20 +- docs/core/event_handler/api_gateway.md | 247 ++- docs/diagram_src/api-middlewares.drawio | 1534 +++++++++++++++++ .../middlewares_catch_exception-dark.svg | 4 + .../middlewares_catch_exception-light.svg | 4 + ...middlewares_catch_route_exception-dark.svg | 4 + ...iddlewares_catch_route_exception-light.svg | 4 + docs/media/middlewares_early_return-dark.svg | 4 + docs/media/middlewares_early_return-light.svg | 4 + docs/media/middlewares_early_return.svg | 3 + .../middlewares_normal_processing-dark.svg | 4 + .../middlewares_normal_processing-light.svg | 4 + .../middlewares_unhandled_exception-dark.svg | 4 + .../middlewares_unhandled_exception-light.svg | 4 + ...lewares_unhandled_route_exception-dark.svg | 4 + ...ewares_unhandled_route_exception-light.svg | 4 + .../src/middleware_early_return.py | 28 + .../src/middleware_early_return_output.json | 6 + .../src/middleware_extending_middlewares.py | 47 + .../src/middleware_getting_started.py | 40 + .../middleware_getting_started_output.json | 13 + .../src/middleware_global_middlewares.py | 24 + .../middleware_global_middlewares_module.py | 41 + .../event_handler_rest/src/split_route.py | 2 +- .../src/split_route_module.py | 6 +- .../src/split_route_prefix_module.py | 6 +- tests/events/apiGatewayProxyEvent.json | 2 +- tests/events/apiGatewayProxyOtherEvent.json | 81 + tests/events/apiGatewayProxyV2Event_GET.json | 68 + .../apiGatewayProxyV2OtherGetEvent.json | 68 + ...wayProxyV2SchemaMiddlwareInvalidEvent.json | 69 + ...tewayProxyV2SchemaMiddlwareValidEvent.json | 69 + ...pigatewayeSchemaMiddlwareInvalidEvent.json | 81 + .../apigatewayeSchemaMiddlwareValidEvent.json | 81 + tests/functional/event_handler/conftest.py | 32 + .../event_handler/test_api_gateway.py | 3 +- .../event_handler/test_api_middlewares.py | 480 ++++++ 44 files changed, 3743 insertions(+), 50 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/middlewares/__init__.py create mode 100644 aws_lambda_powertools/event_handler/middlewares/base.py create mode 100644 aws_lambda_powertools/event_handler/middlewares/schema_validation.py create mode 100644 aws_lambda_powertools/event_handler/types.py create mode 100644 docs/diagram_src/api-middlewares.drawio create mode 100644 docs/media/middlewares_catch_exception-dark.svg create mode 100644 docs/media/middlewares_catch_exception-light.svg create mode 100644 docs/media/middlewares_catch_route_exception-dark.svg create mode 100644 docs/media/middlewares_catch_route_exception-light.svg create mode 100644 docs/media/middlewares_early_return-dark.svg create mode 100644 docs/media/middlewares_early_return-light.svg create mode 100644 docs/media/middlewares_early_return.svg create mode 100644 docs/media/middlewares_normal_processing-dark.svg create mode 100644 docs/media/middlewares_normal_processing-light.svg create mode 100644 docs/media/middlewares_unhandled_exception-dark.svg create mode 100644 docs/media/middlewares_unhandled_exception-light.svg create mode 100644 docs/media/middlewares_unhandled_route_exception-dark.svg create mode 100644 docs/media/middlewares_unhandled_route_exception-light.svg create mode 100644 examples/event_handler_rest/src/middleware_early_return.py create mode 100644 examples/event_handler_rest/src/middleware_early_return_output.json create mode 100644 examples/event_handler_rest/src/middleware_extending_middlewares.py create mode 100644 examples/event_handler_rest/src/middleware_getting_started.py create mode 100644 examples/event_handler_rest/src/middleware_getting_started_output.json create mode 100644 examples/event_handler_rest/src/middleware_global_middlewares.py create mode 100644 examples/event_handler_rest/src/middleware_global_middlewares_module.py create mode 100644 tests/events/apiGatewayProxyOtherEvent.json create mode 100644 tests/events/apiGatewayProxyV2Event_GET.json create mode 100644 tests/events/apiGatewayProxyV2OtherGetEvent.json create mode 100644 tests/events/apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json create mode 100644 tests/events/apiGatewayProxyV2SchemaMiddlwareValidEvent.json create mode 100644 tests/events/apigatewayeSchemaMiddlwareInvalidEvent.json create mode 100644 tests/events/apigatewayeSchemaMiddlwareValidEvent.json create mode 100644 tests/functional/event_handler/test_api_middlewares.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index b716087d38..2163d7d762 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -9,19 +9,7 @@ from enum import Enum from functools import partial from http import HTTPStatus -from typing import ( - Any, - Callable, - Dict, - List, - Match, - Optional, - Pattern, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError @@ -218,13 +206,129 @@ def __init__( cors: bool, compress: bool, cache_control: Optional[str], + middlewares: Optional[List[Callable[..., Response]]], ): + """ + + Parameters + ---------- + + method: str + The HTTP method, example "GET" + rule: Pattern + The route rule, example "/my/path" + func: Callable + The route handler function + cors: bool + Whether or not to enable CORS for this route + compress: bool + Whether or not to enable gzip compression for this route + cache_control: Optional[str] + The cache control header value, example "max-age=3600" + middlewares: Optional[List[Callable[..., Response]]] + The list of route middlewares to be called in order. + """ self.method = method.upper() self.rule = rule self.func = func + self._middleware_stack = func self.cors = cors self.compress = compress self.cache_control = cache_control + self.middlewares = middlewares or [] + + # _middleware_stack_built is used to ensure the middleware stack is only built once. + self._middleware_stack_built = False + + def __call__( + self, + router_middlewares: List[Callable], + app: "ApiGatewayResolver", + route_arguments: Dict[str, str], + ) -> Union[Dict, Tuple, Response]: + """Calling the Router class instance will trigger the following actions: + 1. If Route Middleware stack has not been built, build it + 2. Call the Route Middleware stack wrapping the original function + handler with the app and route arguments. + + Parameters + ---------- + router_middlewares: List[Callable] + The list of Router Middlewares (assigned to ALL routes) + app: "ApiGatewayResolver" + The ApiGatewayResolver instance to pass into the middleware stack + route_arguments: Dict[str, str] + The route arguments to pass to the app function (extracted from the Api Gateway + Lambda Message structure from AWS) + + Returns + ------- + Union[Dict, Tuple, Response] + API Response object in ALL cases, except when the original API route + handler is called which may also return a Dict, Tuple, or Response. + """ + + # Save CPU cycles by building middleware stack once + if not self._middleware_stack_built: + self._build_middleware_stack(router_middlewares=router_middlewares) + + # If debug is turned on then output the middleware stack to the console + if app._debug: + print(f"\nProcessing Route:::{self.func.__name__} ({app.context['_path']})") + # Collect ALL middleware for debug printing - include internal _registered_api_adapter + all_middlewares = router_middlewares + self.middlewares + [_registered_api_adapter] + print("\nMiddleware Stack:") + print("=================") + print("\n".join(getattr(item, "__name__", "Unknown") for item in all_middlewares)) + print("=================") + + # Add Route Arguments to app context + app.append_context(_route_args=route_arguments) + + # Call the Middleware Wrapped _call_stack function handler with the app + return self._middleware_stack(app) + + def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) -> None: + """ + Builds the middleware stack for the handler by wrapping each + handler in an instance of MiddlewareWrapper which is used to contain the state + of each middleware step. + + Middleware is represented by a standard Python Callable construct. Any Middleware + handler wanting to short-circuit the middlware call chain can raise an exception + to force the Python call stack created by the handler call-chain to naturally un-wind. + + This becomes a simple concept for developers to understand and reason with - no additional + gymanstics other than plain old try ... except. + + Notes + ----- + The Route Middleware stack is processed in reverse order. This is so the stack of + middleware handlers is applied in the order of being added to the handler. + """ + all_middlewares = router_middlewares + self.middlewares + logger.debug(f"Building middleware stack: {all_middlewares}") + + # IMPORTANT: + # this must be the last middleware in the stack (tech debt for backward + # compatibility purposes) + # + # This adapter will: + # 1. Call the registered API passing only the expected route arguments extracted from the path + # and not the middleware. + # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) + # and normalise into a Response object so middleware will always have a constant signature + all_middlewares.append(_registered_api_adapter) + + # Wrap the original route handler function in the middleware handlers + # using the MiddlewareWrapper class callable construct in reverse order to + # ensure middleware is applied in the order the user defined. + # + # Start with the route function and wrap from last to the first Middleware handler. + for handler in reversed(all_middlewares): + self._middleware_stack = MiddlewareFrame(current_middleware=handler, next_middleware=self._middleware_stack) + + self._middleware_stack_built = True class ResponseBuilder: @@ -268,7 +372,11 @@ def _has_compression_enabled( bool True if compression is enabled and the "gzip" encoding is accepted, False otherwise. """ - encoding: str = event.get_header_value(name="accept-encoding", default_value="", case_sensitive=False) # type: ignore[assignment] # noqa: E501 + encoding: str = event.get_header_value( + name="accept-encoding", + default_value="", + case_sensitive=False, + ) # noqa: E501 if "gzip" in encoding: if response_compression is not None: return response_compression # e.g., Response(compress=False/True)) @@ -322,6 +430,8 @@ class BaseRouter(ABC): current_event: BaseProxyEvent lambda_context: LambdaContext context: dict + _router_middlewares: List[Callable] = [] + processed_stack_frames: List[str] = [] @abstractmethod def route( @@ -331,10 +441,59 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() - def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + def use(self, middlewares: List[Callable[..., Response]]) -> None: + """ + Add one or more global middlewares that run before/after route specific middleware. + + NOTE: Middlewares are called in insertion order. + + Parameters + ---------- + middlewares: List[Callable[..., Response]] + List of global middlewares to be used + + Examples + -------- + + Add middlewares to be used for every request processed by the Router. + + ```python + from aws_lambda_powertools import Logger + from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + + logger = Logger() + app = APIGatewayRestResolver() + + def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + logger.info("Incoming request", path=app.current_event.path, request=app.current_event.raw_event) + + result = next_middleware(app) + logger.info("Response received", response=result.__dict__) + + return result + + app.use(middlewares=[log_request_response]) + + + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + """ + self._router_middlewares = self._router_middlewares + middlewares + + def get( + self, + rule: str, + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, + ): """Get route decorator with GET `method` Examples @@ -357,9 +516,16 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "GET", cors, compress, cache_control) + return self.route(rule, "GET", cors, compress, cache_control, middlewares) - def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + def post( + self, + rule: str, + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, + ): """Post route decorator with POST `method` Examples @@ -383,9 +549,16 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "POST", cors, compress, cache_control) + return self.route(rule, "POST", cors, compress, cache_control, middlewares) - def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + def put( + self, + rule: str, + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, + ): """Put route decorator with PUT `method` Examples @@ -409,7 +582,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PUT", cors, compress, cache_control) + return self.route(rule, "PUT", cors, compress, cache_control, middlewares) def delete( self, @@ -417,6 +590,7 @@ def delete( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -440,7 +614,7 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "DELETE", cors, compress, cache_control) + return self.route(rule, "DELETE", cors, compress, cache_control, middlewares) def patch( self, @@ -448,6 +622,7 @@ def patch( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -474,7 +649,19 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PATCH", cors, compress, cache_control) + return self.route(rule, "PATCH", cors, compress, cache_control, middlewares) + + def _push_processed_stack_frame(self, frame: str): + """ + Add Current Middleware to the Middleware Stack Frames + The stack frames will be used when exceptions are thrown and Powertools + debug is enabled by developers. + """ + self.processed_stack_frames.append(frame) + + def _reset_processed_stack(self): + """Reset the Processed Stack Frames""" + self.processed_stack_frames.clear() def append_context(self, **additional_context): """Append key=value data as routing context""" @@ -485,6 +672,109 @@ def clear_context(self): self.context.clear() +class MiddlewareFrame: + """ + creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of + middleware functions. Each instance contains the current middleware and the next + middleware function to be called in the stack. + + In this way the middleware stack is constructed in a recursive fashion, with each middleware + calling the next as a simple function call. The actual Python call-stack will contain + each MiddlewareStackWrapper "Frame", meaning any Middleware function can cause the + entire Middleware call chain to be exited early (short-circuited) by raising an exception + or by simply returning early with a custom Response. The decision to short-circuit the middleware + chain is at the user's discretion but instantly available due to the Wrapped nature of the + callable constructs in the Middleware stack and each Middleware function having complete control over + whether the "Next" handler in the stack is called or not. + + Parameters + ---------- + current_middleware : Callable + The current middleware function to be called as a request is processed. + next_middleware : Callable + The next middleware in the middleware stack. + """ + + def __init__( + self, + current_middleware: Callable[..., Any], + next_middleware: Callable[..., Any], + ) -> None: + self.current_middleware: Callable[..., Any] = current_middleware + self.next_middleware: Callable[..., Any] = next_middleware + self._next_middleware_name = next_middleware.__name__ + + @property + def __name__(self) -> str: # noqa: A003 + """Current middleware name + + It ensures backward compatibility with view functions being callable. This + improves debugging since we need both current and next middlewares/callable names. + """ + return self.current_middleware.__name__ + + def __str__(self) -> str: + """Identify current middleware identity and call chain for debugging purposes.""" + middleware_name = self.__name__ + return f"[{middleware_name}] next call chain is {middleware_name} -> {self._next_middleware_name}" + + def __call__(self, app: "ApiGatewayResolver") -> Union[Dict, Tuple, Response]: + """ + Call the middleware Frame to process the request. + + Parameters + ---------- + app: BaseRouter + The router instance + + Returns + ------- + Union[Dict, Tuple, Response] + (tech-debt for backward compatibility). The response type should be a + Response object in all cases excepting when the original API route handler + is called which will return one of 3 outputs. + + """ + # Do debug printing and push processed stack frame AFTER calling middleware + # else the stack frame text of `current calling next` is confusing. + logger.debug("MiddlewareFrame: %s", self) + app._push_processed_stack_frame(str(self)) + + return self.current_middleware(app, self.next_middleware) + + +def _registered_api_adapter( + app: "ApiGatewayResolver", + next_middleware: Callable[..., Any], +) -> Union[Dict, Tuple, Response]: + """ + Calls the registered API using the "_route_args" from the Resolver context to ensure the last call + in the chain will match the API route function signature and ensure that Powertools passes the API + route handler the expected arguments. + + **IMPORTANT: This internal middleware ensures the actual API route is called with the correct call signature + and it MUST be the final frame in the middleware stack. This can only be removed when the API Route + function accepts `app: BaseRouter` as the first argument - which is the breaking change. + + Parameters + ---------- + app: ApiGatewayResolver + The API Gateway resolver + next_middleware: Callable[..., Any] + The function to handle the API + + Returns + ------- + Response + The API Response Object + + """ + route_args: Dict = app.context.get("_route_args", {}) + logger.debug(f"Calling API Route Handler: {route_args}") + + return app._to_response(next_middleware(**route_args)) + + class ApiGatewayResolver(BaseRouter): """API Gateway and ALB proxy resolver @@ -550,6 +840,7 @@ def __init__( self._debug = self._has_debug(debug) self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution + self.processed_stack_frames = [] # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -561,6 +852,7 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -573,7 +865,15 @@ def register_resolver(func: Callable): cors_enabled = cors for item in methods: - _route = Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control) + _route = Route( + item, + self._compile_regex(rule), + func, + cors_enabled, + compress, + cache_control, + middlewares, + ) # The more specific route wins. # We store dynamic (/studies/{studyid}) and static routes (/studies/fetch) separately. @@ -594,6 +894,7 @@ def register_resolver(func: Callable): if cors_enabled: logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS") self._cors_methods.add(item.upper()) + return func return register_resolver @@ -628,7 +929,16 @@ def resolve(self, event, context) -> Dict[str, Any]: BaseRouter.lambda_context = context response = self._resolve().build(self.current_event, self._cors) + + # Debug print Processed Middlewares + if self._debug: + print("\nProcessed Middlewares:") + print("======================") + print("\n".join(self.processed_stack_frames)) + print("======================") + self.clear_context() + return response def __call__(self, event, context) -> Any: @@ -703,6 +1013,9 @@ def _resolve(self) -> ResponseBuilder: match_results: Optional[Match] = route.rule.match(path) if match_results: logger.debug("Found a registered route. Calling function") + # Add matched Route reference into the Resolver context + self.append_context(_route=route, _path=path) + return self._call_route(route, match_results.groupdict()) # pass fn args logger.debug(f"No match found for path {path} and method {method}") @@ -765,15 +1078,25 @@ def _not_found(self, method: str) -> ResponseBuilder: ), ) - def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: + def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" try: - return ResponseBuilder(self._to_response(route.func(**args)), route) + # Reset Processed stack for Middleware (for debugging purposes) + self._reset_processed_stack() + + return ResponseBuilder( + self._to_response( + route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments), + ), + route, + ) except Exception as exc: + # If exception is handled then return the response builder to reduce noise response_builder = self._call_exception_handler(exc, route) if response_builder: return response_builder + logger.exception(exc) if self._debug: # If the user has turned on debug mode, # we'll let the original exception propagate so @@ -874,8 +1197,12 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None # Add reference to parent ApiGatewayResolver to support use cases where people subclass it to add custom logic router.api_resolver = self - # Merge app and router context + logger.debug("Merging App context with Router context") self.context.update(**router.context) + + logger.debug("Appending Router middlewares into App middlewares.") + self._router_middlewares = self._router_middlewares + router._router_middlewares + # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) router.context = self.context @@ -887,7 +1214,15 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None rule = prefix if rule == "/" else f"{prefix}{rule}" new_route = (rule, *route[1:]) - self.route(*new_route)(func) + # Middlewares are stored by route separately - must grab them to include + middlewares = router._routes_with_middleware.get(new_route) + + # Need to use "type: ignore" here since mypy does not like a named parameter after + # tuple expansion since may cause duplicate named parameters in the function signature. + # In this case this is not possible since the tuple expansion is from a hashable source + # and the `middlewares` List is a non-hashable structure so will never be included. + # Still need to ignore for mypy checks or will cause failures (false-positive) + self.route(*new_route, middlewares=middlewares)(func) # type: ignore class Router(BaseRouter): @@ -895,6 +1230,7 @@ class Router(BaseRouter): def __init__(self): self._routes: Dict[tuple, Callable] = {} + self._routes_with_middleware: Dict[tuple, List[Callable]] = {} self.api_resolver: Optional[BaseRouter] = None self.context = {} # early init as customers might add context before event resolution @@ -905,11 +1241,26 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key methods = (method,) if isinstance(method, str) else tuple(method) - self._routes[(rule, methods, cors, compress, cache_control)] = func + + route_key = (rule, methods, cors, compress, cache_control) + + # Collate Middleware for routes + if middlewares is not None: + for handler in middlewares: + if self._routes_with_middleware.get(route_key) is None: + self._routes_with_middleware[route_key] = [handler] + else: + self._routes_with_middleware[route_key].append(handler) + else: + self._routes_with_middleware[route_key] = [] + + self._routes[route_key] = func + return func return register_route @@ -936,9 +1287,10 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. - return super().route(rule.rstrip("/"), method, cors, compress, cache_control) + return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares) # Override _compile_regex to exclude trailing slashes for route resolution @staticmethod diff --git a/aws_lambda_powertools/event_handler/middlewares/__init__.py b/aws_lambda_powertools/event_handler/middlewares/__init__.py new file mode 100644 index 0000000000..068ce9c04b --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/__init__.py @@ -0,0 +1,3 @@ +from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware + +__all__ = ["BaseMiddlewareHandler", "NextMiddleware"] diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py new file mode 100644 index 0000000000..32a4486bb3 --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -0,0 +1,122 @@ +from abc import ABC, abstractmethod +from typing import Generic + +from typing_extensions import Protocol + +from aws_lambda_powertools.event_handler.api_gateway import Response +from aws_lambda_powertools.event_handler.types import EventHandlerInstance + + +class NextMiddleware(Protocol): + def __call__(self, app: EventHandlerInstance) -> Response: + """Protocol for callback regardless of next_middleware(app), get_response(app) etc""" + ... + + def __name__(self) -> str: # noqa A003 + """Protocol for name of the Middleware""" + ... + + +class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC): + """Base implementation for Middlewares to run code before and after in a chain. + + + This is the middleware handler function where middleware logic is implemented. + The next middleware handler is represented by `next_middleware`, returning a Response object. + + Examples + -------- + + **Correlation ID Middleware** + + ```python + import requests + + from aws_lambda_powertools import Logger + from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware + + app = APIGatewayRestResolver() + logger = Logger() + + + class CorrelationIdMiddleware(BaseMiddlewareHandler): + def __init__(self, header: str): + super().__init__() + self.header = header + + def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + # BEFORE logic + request_id = app.current_event.request_context.request_id + correlation_id = app.current_event.get_header_value( + name=self.header, + default_value=request_id, + ) + + # Call next middleware or route handler ('/todos') + response = next_middleware(app) + + # AFTER logic + response.headers[self.header] = correlation_id + + return response + + + @app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")]) + def get_todos(): + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + + @logger.inject_lambda_context + def lambda_handler(event, context): + return app.resolve(event, context) + + ``` + + """ + + @abstractmethod + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + """ + The Middleware Handler + + Parameters + ---------- + app: EventHandlerInstance + An instance of an Event Handler that implements ApiGatewayResolver + next_middleware: NextMiddleware + The next middleware handler in the chain + + Returns + ------- + Response + The response from the next middleware handler in the chain + + """ + raise NotImplementedError() + + @property + def __name__(self) -> str: # noqa A003 + return str(self.__class__.__name__) + + def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + """ + The Middleware handler function. + + Parameters + ---------- + app: ApiGatewayResolver + An instance of an Event Handler that implements ApiGatewayResolver + next_middleware: NextMiddleware + The next middleware handler in the chain + + Returns + ------- + Response + The response from the next middleware handler in the chain + """ + return self.handler(app, next_middleware) diff --git a/aws_lambda_powertools/event_handler/middlewares/schema_validation.py b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py new file mode 100644 index 0000000000..66be47a48f --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/schema_validation.py @@ -0,0 +1,124 @@ +import logging +from typing import Dict, Optional + +from aws_lambda_powertools.event_handler.api_gateway import Response +from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware +from aws_lambda_powertools.event_handler.types import EventHandlerInstance +from aws_lambda_powertools.utilities.validation import validate +from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError + +logger = logging.getLogger(__name__) + + +class SchemaValidationMiddleware(BaseMiddlewareHandler): + """Middleware to validate API request and response against JSON Schema using the [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/). + + Examples + -------- + **Validating incoming event** + + ```python + import requests + + from aws_lambda_powertools import Logger + from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware + from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware + + app = APIGatewayRestResolver() + logger = Logger() + json_schema_validation = SchemaValidationMiddleware(inbound_schema=INCOMING_JSON_SCHEMA) + + + @app.get("/todos", middlewares=[json_schema_validation]) + def get_todos(): + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + + @logger.inject_lambda_context + def lambda_handler(event, context): + return app.resolve(event, context) + ``` + """ + + def __init__( + self, + inbound_schema: Dict, + inbound_formats: Optional[Dict] = None, + outbound_schema: Optional[Dict] = None, + outbound_formats: Optional[Dict] = None, + ): + """See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters. + + Parameters + ---------- + inbound_schema : Dict + JSON Schema to validate incoming event + inbound_formats : Optional[Dict], optional + Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None + JSON Schema to validate outbound event, by default None + outbound_formats : Optional[Dict], optional + Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None + """ # noqa: E501 + super().__init__() + self.inbound_schema = inbound_schema + self.inbound_formats = inbound_formats + self.outbound_schema = outbound_schema + self.outbound_formats = outbound_formats + + def bad_response(self, error: SchemaValidationError) -> Response: + message: str = f"Bad Response: {error.message}" + logger.debug(message) + raise BadRequestError(message) + + def bad_request(self, error: SchemaValidationError) -> Response: + message: str = f"Bad Request: {error.message}" + logger.debug(message) + raise BadRequestError(message) + + def bad_config(self, error: InvalidSchemaFormatError) -> Response: + logger.debug(f"Invalid Schema Format: {error}") + raise InternalServerError("Internal Server Error") + + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + """Validates incoming JSON payload (body) against JSON Schema provided. + + Parameters + ---------- + app : EventHandlerInstance + An instance of an Event Handler + next_middleware : NextMiddleware + Callable to get response from the next middleware or route handler in the chain + + Returns + ------- + Response + It can return three types of response objects + + - Original response: Propagates HTTP response returned from the next middleware if validation succeeds + - HTTP 400: Payload or response failed JSON Schema validation + - HTTP 500: JSON Schema provided has incorrect format + """ + try: + validate(event=app.current_event.json_body, schema=self.inbound_schema, formats=self.inbound_formats) + except SchemaValidationError as error: + return self.bad_request(error) + except InvalidSchemaFormatError as error: + return self.bad_config(error) + + result = next_middleware(app) + + if self.outbound_formats is not None: + try: + validate(event=result.body, schema=self.inbound_schema, formats=self.inbound_formats) + except SchemaValidationError as error: + return self.bad_response(error) + except InvalidSchemaFormatError as error: + return self.bad_config(error) + + return result diff --git a/aws_lambda_powertools/event_handler/types.py b/aws_lambda_powertools/event_handler/types.py new file mode 100644 index 0000000000..11cd146a57 --- /dev/null +++ b/aws_lambda_powertools/event_handler/types.py @@ -0,0 +1,5 @@ +from typing import TypeVar + +from aws_lambda_powertools.event_handler import ApiGatewayResolver + +EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver) diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py index e4e10192e5..b29c04cbe6 100644 --- a/aws_lambda_powertools/shared/types.py +++ b/aws_lambda_powertools/shared/types.py @@ -1,5 +1,14 @@ +import sys from typing import Any, Callable, Dict, List, TypeVar, Union AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 # JSON primitives only, mypy doesn't support recursive tho JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] + + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + +__all__ = ["Protocol"] diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 7a3fc8ab40..fa7c529604 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,7 +1,7 @@ import base64 import json from collections.abc import Mapping -from typing import Any, Callable, Dict, Iterator, List, Optional +from typing import Any, Callable, Dict, Iterator, List, Optional, overload from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer from aws_lambda_powertools.utilities.data_classes.shared_functions import ( @@ -156,7 +156,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) default_value=default_value, ) - # Maintenance: missing @overload to ensure return type is a str when default_value is set + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: + ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: + ... + def get_header_value( self, name: str, diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index ffa9cb263a..35194f1f3f 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, overload from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -91,6 +91,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) default_value=default_value, ) + @overload + def get_header_value( + self, + name: str, + default_value: str, + case_sensitive: Optional[bool] = False, + ) -> str: + ... + + @overload + def get_header_value( + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, + ) -> Optional[str]: + ... + def get_header_value( self, name: str, diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index dcfa38f6f9..dd249ec665 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -10,6 +10,7 @@ Event handler for Amazon API Gateway REST and HTTP APIs, Application Loader Bala * Lightweight routing to reduce boilerplate for API Gateway REST/HTTP API, ALB and Lambda Function URLs. * Support for CORS, binary and Gzip compression, Decimals JSON encoding and bring your own JSON serializer * Built-in integration with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"} for self-documented event schema +* Works with micro function (one or a few routes) and monolithic functions (all routes) ## Getting started @@ -353,14 +354,226 @@ For convenience, these are the default values when using `CORSConfig` to enable ???+ tip "Multiple origins?" If you need to allow multiple origins, pass the additional origins using the `extra_origins` key. -| Key | Value | Note | -| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it | -| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` | -| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience | +| Key | Value | Note | +| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it | +| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` | +| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience | | **[expose_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers){target="_blank" rel="nofollow"}**: `List[str]` | `[]` | Any additional header beyond the [safe listed by CORS specification](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header){target="_blank" rel="nofollow"}. | -| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank" rel="nofollow"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway | -| **[allow_credentials](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials){target="_blank" rel="nofollow"}**: `bool` | `False` | Only necessary when you need to expose cookies, authorization headers or TLS client certificates. | +| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank" rel="nofollow"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway | +| **[allow_credentials](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials){target="_blank" rel="nofollow"}**: `bool` | `False` | Only necessary when you need to expose cookies, authorization headers or TLS client certificates. | + +### Middleware + +```mermaid +stateDiagram + direction LR + + EventHandler: GET /todo + Before: Before response + Next: next_middleware() + MiddlewareLoop: Middleware loop + AfterResponse: After response + MiddlewareFinished: Modified response + Response: Final response + + EventHandler --> Middleware: Has middleware? + state MiddlewareLoop { + direction LR + Middleware --> Before + Before --> Next + Next --> Middleware: More middlewares? + Next --> AfterResponse + } + AfterResponse --> MiddlewareFinished + MiddlewareFinished --> Response + EventHandler --> Response: No middleware +``` + +A middleware is a function you register per route to **intercept** or **enrich** a **request before** or **after** any response. + +Each middleware function receives the following arguments: + +1. **app**. An Event Handler instance so you can access incoming request information, Lambda context, etc. +2. **next_middleware**. A function to get the next middleware or route's response. + +Here's a sample middleware that extracts and injects correlation ID, using `APIGatewayRestResolver` (works for any [Resolver](#event-resolvers)): + +=== "middleware_getting_started.py" + + ```python hl_lines="11 22 29" title="Your first middleware to extract and inject correlation ID" + --8<-- "examples/event_handler_rest/src/middleware_getting_started.py" + ``` + + 1. You can access current request like you normally would. + 2. [Shared context is available](#sharing-contextual-data) to any middleware, Router and App instances. + 3. Get response from the next middleware (if any) or from `/todos` route. + 4. You can manipulate headers, body, or status code before returning it. + 5. Register one or more middlewares in order of execution. + +=== "middleware_getting_started_output.json" + + ```json hl_lines="9-10" + --8<-- "examples/event_handler_rest/src/middleware_getting_started_output.json" + ``` + +#### Global middlewares + +
+![Combining middlewares](../../media/middlewares_normal_processing-light.svg#only-light) +![Combining middlewares](../../media/middlewares_normal_processing-dark.svg#only-dark) + +_Request flowing through multiple registered middlewares_ +
+ +You can use `app.use` to register middlewares that should always run regardless of the route, also known as global middlewares. + +Event Handler **calls global middlewares first**, then middlewares defined at the route level. Here's an example with both middlewares: + +=== "middleware_global_middlewares.py" + + > Use [debug mode](#debug-mode) if you need to log request/response. + + ```python hl_lines="10" + --8<-- "examples/event_handler_rest/src/middleware_global_middlewares.py" + ``` + + 1. A separate file where our middlewares are to keep this example focused. + 2. We register `log_request_response` as a global middleware to run before middleware. + ```mermaid + stateDiagram + direction LR + + GlobalMiddleware: Log request response + RouteMiddleware: Inject correlation ID + EventHandler: Event Handler + + EventHandler --> GlobalMiddleware + GlobalMiddleware --> RouteMiddleware + ``` + +=== "middleware_global_middlewares_module.py" + + ```python hl_lines="8" + --8<-- "examples/event_handler_rest/src/middleware_global_middlewares_module.py" + ``` + +#### Returning early + +
+![Short-circuiting middleware chain](../../media/middlewares_early_return-light.svg#only-light) +![Short-circuiting middleware chain](../../media/middlewares_early_return-dark.svg#only-dark) + +_Interrupting request flow by returning early_ +
+ +Imagine you want to stop processing a request if something is missing, or return immediately if you've seen this request before. + +In these scenarios, you short-circuit the middleware processing logic by returning a [Response object](#fine-grained-responses), or raising a [HTTP Error](#raising-http-errors). This signals to Event Handler to stop and run each `After` logic left in the chain all the way back. + +Here's an example where we prevent any request that doesn't include a correlation ID header: + +=== "middleware_early_return.py" + + ```python hl_lines="12" + --8<-- "examples/event_handler_rest/src/middleware_early_return.py" + ``` + + 1. This middleware will raise an exception if correlation ID header is missing. + 2. This code section will not run if `enforce_correlation_id` returns early. + +=== "middleware_global_middlewares_module.py" + + ```python hl_lines="35 38" + --8<-- "examples/event_handler_rest/src/middleware_global_middlewares_module.py" + ``` + + 1. Raising an exception OR returning a Response object early will short-circuit the middleware chain. + +=== "middleware_early_return_output.json" + + ```python hl_lines="2-3" + --8<-- "examples/event_handler_rest/src/middleware_early_return_output.json" + ``` + +#### Handling exceptions + +!!! tip "For catching exceptions more broadly, we recommend you use the [exception_handler](#exception-handling) decorator." + +By default, any unhandled exception in the middleware chain is eventually propagated as a HTTP 500 back to the client. + +While there isn't anything special on how to use [`try/catch`](https://docs.python.org/3/tutorial/errors.html#handling-exceptions){target="_blank" rel="nofollow"} for middlewares, it is important to visualize how Event Handler deals with them under the following scenarios: + +=== "Unhandled exception from route handler" + + An exception wasn't caught by any middleware during `next_middleware()` block, therefore it propagates all the way back to the client as HTTP 500. + +
+ ![Unhandled exceptions](../../media/middlewares_unhandled_route_exception-light.svg#only-light) + ![Unhandled exceptions](../../media/middlewares_unhandled_route_exception-dark.svg#only-dark) + + _Unhandled route exceptions propagate back to the client_ +
+ +=== "Route handler exception caught by a middleware" + + An exception was only caught by the third middleware, resuming the normal execution of each `After` logic for the second and first middleware. + +
+ ![Middleware handling exceptions](../../media/middlewares_catch_route_exception-light.svg#only-light) + ![Middleware handling exceptions](../../media/middlewares_catch_route_exception-dark.svg#only-dark) + + _Unhandled route exceptions propagate back to the client_ +
+ +=== "Middleware short-circuit by raising exception" + + The third middleware short-circuited the chain by raising an exception and completely skipping the fourth middleware. Because we only caught it in the first middleware, it skipped the `After` logic in the second middleware. + +
+ ![Catching exceptions](../../media/middlewares_catch_exception-light.svg#only-light) + ![Catching exceptions](../../media/middlewares_catch_exception-dark.svg#only-dark) + + _Middleware handling short-circuit exceptions_ +
+ +#### Extending middlewares + +You can implement `BaseMiddlewareHandler` interface to create middlewares that accept configuration, or perform complex operations (_see [being a good citizen section](#being-a-good-citizen)_). + +As a practical example, let's refactor our correlation ID middleware so it accepts a custom HTTP Header to look for. + +```python hl_lines="5 11 23 36" title="Authoring class-based middlewares with BaseMiddlewareHandler" +--8<-- "examples/event_handler_rest/src/middleware_extending_middlewares.py" +``` + +1. You can add any constructor argument like you normally would +2. We implement `handler` just like we [did before](#middleware) with the only exception of the `self` argument, since it's a method. +3. Get response from the next middleware (if any) or from `/todos` route. +4. Register an instance of `CorrelationIdMiddleware`. + +!!! note "Class-based **vs** function-based middlewares" + When registering a middleware, we expect a callable in both cases. For class-based middlewares, `BaseMiddlewareHandler` is doing the work of calling your `handler` method with the correct parameters, hence why we expect an instance of it. + +#### Native middlewares + +These are native middlewares that may become native features depending on customer demand. + +| Middleware | Purpose | +| ---------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- | +| [SchemaValidationMiddleware](/api/event_handler/middlewares/schema_validation.html){target="_blank"} | Validates API request body and response against JSON Schema, using [Validation utility](../../utilities/validation.md){target="_blank"} | + +#### Being a good citizen + +Middlewares can add subtle improvements to request/response processing, but also add significant complexity if you're not careful. + +Keep the following in mind when authoring middlewares for Event Handler: + +1. **Use built-in features over middlewares**. We include built-in features like [CORS](#cors), [compression](#compress), [binary responses](#binary-responses), [global exception handling](#exception-handling), and [debug mode](#debug-mode) to reduce the need for middlewares. +2. **Call the next middleware**. Return the result of `next_middleware(app)`, or a [Response object](#fine-grained-responses) when you want to [return early](#returning-early). +3. **Keep a lean scope**. Focus on a single task per middleware to ease composability and maintenance. In [debug mode](#debug-mode), we also print out the order middlewares will be triggered to ease operations. +4. **Catch your own exceptions**. Catch and handle known exceptions to your logic. Unless you want to raise [HTTP Errors](#raising-http-errors), or propagate specific exceptions to the client. To catch all and any exceptions, we recommend you use the [exception_handler](#exception-handling) decorator. +5. **Use context to share data**. Use `app.append_context` to [share contextual data](#sharing-contextual-data) between middlewares and route handlers, and `app.context.get(key)` to fetch them. We clear all contextual data at the end of every request. ### Fine grained responses @@ -479,14 +692,19 @@ You can instruct event handler to use a custom serializer to best suit your need ### Split routes with Router -As you grow the number of routes a given Lambda function should handle, it is natural to split routes into separate files to ease maintenance - That's where the `Router` feature is useful. +As you grow the number of routes a given Lambda function should handle, it is natural to either break into smaller Lambda functions, or split routes into separate files to ease maintenance - that's where the `Router` feature is useful. Let's assume you have `split_route.py` as your Lambda function entrypoint and routes in `split_route_module.py`. This is how you'd use the `Router` feature. + + === "split_route_module.py" We import **Router** instead of **APIGatewayRestResolver**; syntax wise is exactly the same. + !!! info + This means all methods, including [middleware](#middleware) will work as usual. + ```python hl_lines="5 13 16 25 28" --8<-- "examples/event_handler_rest/src/split_route_module.py" ``` @@ -495,10 +713,16 @@ Let's assume you have `split_route.py` as your Lambda function entrypoint and ro We use `include_router` method and include all user routers registered in the `router` global object. + !!! note + This method merges routes, [context](#sharing-contextual-data) and [middleware](#middleware) from `Router` into the main resolver instance (`APIGatewayRestResolver()`). + ```python hl_lines="11" --8<-- "examples/event_handler_rest/src/split_route.py" ``` + 1. When using [middleware](#middleware) in both `Router` and main resolver, you can make `Router` middlewares to take precedence by using `include_router` before `app.use()`. + + #### Route prefix In the previous example, `split_route_module.py` routes had a `/todos` prefix. This might grow over time and become repetitive. @@ -536,11 +760,8 @@ You can use specialized router classes according to the type of event that you a You can use `append_context` when you want to share data between your App and Router instances. Any data you share will be available via the `context` dictionary available in your App or Router context. -???+ info - For safety, we always clear any data available in the `context` dictionary after each invocation. - -???+ tip - This can also be useful for middlewares injecting contextual information before a request is processed. +???+ info "We always clear data available in `context` after each invocation." + This can be useful for middlewares injecting contextual information before a request is processed. === "split_route_append_context.py" diff --git a/docs/diagram_src/api-middlewares.drawio b/docs/diagram_src/api-middlewares.drawio new file mode 100644 index 0000000000..874688f1e4 --- /dev/null +++ b/docs/diagram_src/api-middlewares.drawio @@ -0,0 +1,1534 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/media/middlewares_catch_exception-dark.svg b/docs/media/middlewares_catch_exception-dark.svg new file mode 100644 index 0000000000..558f54652d --- /dev/null +++ b/docs/media/middlewares_catch_exception-dark.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Short-Circuited via Exception
from Middleware-3
Short-Circuited via Exception...
After processing for
Middleware-1 only
After processing for...

Response from
Middleware-1
Response from...
@app.get("/todos")
@app.get("/todos")
Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Middleware-4 and API Route Handler were not called due to an exception in the previous step
Middleware-4 and API Route Handler were...
Before
Before
After
After
Next (Catch Exceptions)
Next (Catch E...
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_catch_exception-light.svg b/docs/media/middlewares_catch_exception-light.svg new file mode 100644 index 0000000000..37fd79ad07 --- /dev/null +++ b/docs/media/middlewares_catch_exception-light.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Short-Circuited via Exception
from Middleware-3
Short-Circuited via Exception...
After processing for
Middleware-1 only
After processing for...

Response from
Middleware-1
Response from...
@app.get("/todos")
@app.get("/todos")
Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Middleware-4 and API Route Handler were not called due to an exception in the previous step
Middleware-4 and API Route Handler were...
Before
Before
After
After
Next (Catch Exceptions)
Next (Catch E...
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_catch_route_exception-dark.svg b/docs/media/middlewares_catch_route_exception-dark.svg new file mode 100644 index 0000000000..be51e7f8de --- /dev/null +++ b/docs/media/middlewares_catch_route_exception-dark.svg @@ -0,0 +1,4 @@ + + + +
Next
Next
Event
Handler
Event...
API Route
Handler
API Route...
Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Call stack unwinds to first try..except block within middleware
Middleware-3 returns normal response which is also processed
by Middleware-2 and Middleware-1 after processing
Call stack unwinds to first try..except block within middleware...
@app.get("/todos")
@app.get("/todos")
Before
Before
After
After
Before
Before
After
After
Next
Next
Before
Before
After
After
Next (Catch Exceptions)
Next (Catch E...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_catch_route_exception-light.svg b/docs/media/middlewares_catch_route_exception-light.svg new file mode 100644 index 0000000000..70787b489a --- /dev/null +++ b/docs/media/middlewares_catch_route_exception-light.svg @@ -0,0 +1,4 @@ + + + +
Next
Next
Event
Handler
Event...
API Route
Handler
API Route...
Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Call stack unwinds to first try..except block within middleware
Middleware-3 returns normal response which is also processed
by Middleware-2 and Middleware-1 after processing
Call stack unwinds to first try..except block within middleware...
@app.get("/todos")
@app.get("/todos")
Before
Before
After
After
Before
Before
After
After
Next
Next
Before
Before
After
After
Next (Catch Exceptions)
Next (Catch E...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_early_return-dark.svg b/docs/media/middlewares_early_return-dark.svg new file mode 100644 index 0000000000..3a99cb1a22 --- /dev/null +++ b/docs/media/middlewares_early_return-dark.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Short-Circuited Response from Middleware-3
Short-Circuited Response from Middleware-3
Middleware-4 and API Route Handler not called due to early
return in previous step
Middleware-4 and API Route Handler not c...
After Middleware processing occurs for
Middleware-2 and Middleware-1 (in that order)
After Middleware processing occurs for...
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Middleware-4
Middleware-4
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
@app.get("/todos")
@app.get("/todos")
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_early_return-light.svg b/docs/media/middlewares_early_return-light.svg new file mode 100644 index 0000000000..01b01268c7 --- /dev/null +++ b/docs/media/middlewares_early_return-light.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Short-Circuited Response from Middleware-3
Short-Circuited Response from Middleware-3
Middleware-4 and API Route Handler not called due to early
return in previous step
Middleware-4 and API Route Handler not c...
After Middleware processing occurs for
Middleware-2 and Middleware-1 (in that order)
After Middleware processing occurs for...
Middleware-1
Middleware-1
Middleware-3
Middleware-3
Middleware-4
Middleware-4
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
@app.get("/todos")
@app.get("/todos")
Before
Before
After
After
Next
Next
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_early_return.svg b/docs/media/middlewares_early_return.svg new file mode 100644 index 0000000000..5de73b3941 --- /dev/null +++ b/docs/media/middlewares_early_return.svg @@ -0,0 +1,3 @@ + + +
Event
Handler
Event<br>Handler
API Route
Handler
API Route<br>Handler
Before
Before
After
After
Next
Next
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Short-Circuited Response from Middleware-3
<font style="font-size: 14px"><b>Short-Circuited Response from Middleware-3</b></font>
Middleware and Route Handler beyond
Middleware-3 not processed
[Not supported by viewer]
Run "After" logic for Middleware-2 then Middleware-1
<span style="font-size: 14px"><b>Run "After" logic for Middleware-2 then Middleware-1<br></b></span>
Not called due to early return in previous step
<b>Not called due to early return in previous step<br></b>
@app.get("/todos")
@app.get("/todos")
\ No newline at end of file diff --git a/docs/media/middlewares_normal_processing-dark.svg b/docs/media/middlewares_normal_processing-dark.svg new file mode 100644 index 0000000000..9e8cc8662b --- /dev/null +++ b/docs/media/middlewares_normal_processing-dark.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Before
Before
After
After
Next
Next
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
@app.get("/todos")
@app.get("/todos")
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_normal_processing-light.svg b/docs/media/middlewares_normal_processing-light.svg new file mode 100644 index 0000000000..62c5368c7e --- /dev/null +++ b/docs/media/middlewares_normal_processing-light.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
API Route
Handler
API Route...
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
@app.get("/todos")
@app.get("/todos")
Before
Before
After
After
Next
Next
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_unhandled_exception-dark.svg b/docs/media/middlewares_unhandled_exception-dark.svg new file mode 100644 index 0000000000..9278db86fa --- /dev/null +++ b/docs/media/middlewares_unhandled_exception-dark.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
Short-Circuited from Middleware-3 via an Exception
Short-Circuited from Middleware-3 via an Exception
No other Middleware "after" components will be processed
unless exception is captured from next() call.
No other Middleware "after" components will be processed...
Middleware-4 and API Route Handler were not called due to an exception in the previous step
Middleware-4 and API Route Handler were...

Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
API Route
Handler
API Route...
@app.get("/todos")
@app.get("/todos")
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_unhandled_exception-light.svg b/docs/media/middlewares_unhandled_exception-light.svg new file mode 100644 index 0000000000..b937710991 --- /dev/null +++ b/docs/media/middlewares_unhandled_exception-light.svg @@ -0,0 +1,4 @@ + + + +
Event
Handler
Event...
Short-Circuited from Middleware-3 via an Exception
Short-Circuited from Middleware-3 via an Exception
No other Middleware "after" components will be processed
unless exception is captured from next() call.
No other Middleware "after" components will be processed...

Middleware-1
Middleware-1
Middleware-2
Middleware-2
Middleware-3
Middleware-3
Middleware-4
Middleware-4
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
Before
Before
After
After
Next
Next
API Route
Handler
API Route...
@app.get("/todos")
@app.get("/todos")
Middleware-4 and API Route Handler were not called due to an exception in the previous step
Middleware-4 and API Route Handler were...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_unhandled_route_exception-dark.svg b/docs/media/middlewares_unhandled_route_exception-dark.svg new file mode 100644 index 0000000000..fc68c7a97c --- /dev/null +++ b/docs/media/middlewares_unhandled_route_exception-dark.svg @@ -0,0 +1,4 @@ + + + +
Next
Next
Event
Handler
Event...
API Route
Handler
API Route...
Before
Before
After
After
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Call stack unwinds to Event Handler after it has been captured and handled by Powertools
Call stack unwinds to Event Handler after it has been captured and handled by Powertools
@app.get("/todos")
@app.get("/todos")
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/media/middlewares_unhandled_route_exception-light.svg b/docs/media/middlewares_unhandled_route_exception-light.svg new file mode 100644 index 0000000000..0beba842a6 --- /dev/null +++ b/docs/media/middlewares_unhandled_route_exception-light.svg @@ -0,0 +1,4 @@ + + + +
Next
Next
Event
Handler
Event...
API Route
Handler
API Route...
Before
Before
After
After
Middleware-1
Middleware-1
Before
Before
After
After
Next
Next
Middleware-2
Middleware-2
Before
Before
After
After
Next
Next
Middleware-3
Middleware-3
Before
Before
After
After
Next
Next
Middleware-4
Middleware-4
Call stack unwinds to Event Handler after it has been captured and handled by Powertools
Call stack unwinds to Event Handler after it has been captured and handled by Powertools
@app.get("/todos")
@app.get("/todos")
Text is not SVG - cannot display
\ No newline at end of file diff --git a/examples/event_handler_rest/src/middleware_early_return.py b/examples/event_handler_rest/src/middleware_early_return.py new file mode 100644 index 0000000000..1d8b3af159 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_early_return.py @@ -0,0 +1,28 @@ +import middleware_global_middlewares_module +import requests + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + +app = APIGatewayRestResolver() +logger = Logger() +app.use( + middlewares=[ + middleware_global_middlewares_module.log_request_response, + middleware_global_middlewares_module.enforce_correlation_id, # (1)! + ], +) + + +@app.get("/todos") +def get_todos(): + todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") # (2)! + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + +@logger.inject_lambda_context +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/middleware_early_return_output.json b/examples/event_handler_rest/src/middleware_early_return_output.json new file mode 100644 index 0000000000..850b63764d --- /dev/null +++ b/examples/event_handler_rest/src/middleware_early_return_output.json @@ -0,0 +1,6 @@ +{ + "statusCode": 400, + "body": "Correlation ID header is now mandatory", + "isBase64Encoded": false, + "multiValueHeaders": {} +} \ No newline at end of file diff --git a/examples/event_handler_rest/src/middleware_extending_middlewares.py b/examples/event_handler_rest/src/middleware_extending_middlewares.py new file mode 100644 index 0000000000..e492caacf4 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_extending_middlewares.py @@ -0,0 +1,47 @@ +import requests + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware + +app = APIGatewayRestResolver() +logger = Logger() + + +class CorrelationIdMiddleware(BaseMiddlewareHandler): + def __init__(self, header: str): # (1)! + """Extract and inject correlation ID in response + + Parameters + ---------- + header : str + HTTP Header to extract correlation ID + """ + super().__init__() + self.header = header + + def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: # (2)! + request_id = app.current_event.request_context.request_id + correlation_id = app.current_event.get_header_value( + name=self.header, + default_value=request_id, + ) + + response = next_middleware(app) # (3)! + response.headers[self.header] = correlation_id + + return response + + +@app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")]) # (4)! +def get_todos(): + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + +@logger.inject_lambda_context +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/middleware_getting_started.py b/examples/event_handler_rest/src/middleware_getting_started.py new file mode 100644 index 0000000000..6968c85e88 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_getting_started.py @@ -0,0 +1,40 @@ +import requests + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + +app = APIGatewayRestResolver() +logger = Logger() + + +def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + request_id = app.current_event.request_context.request_id # (1)! + + # Use API Gateway REST API request ID if caller didn't include a correlation ID + correlation_id = app.current_event.headers.get("x-correlation-id", request_id) + + # Inject correlation ID in shared context and Logger + app.append_context(correlation_id=correlation_id) # (2)! + logger.set_correlation_id(request_id) + + # Get response from next middleware OR /todos route + result = next_middleware(app) # (3)! + + # Include Correlation ID in the response back to caller + result.headers["x-correlation-id"] = correlation_id # (4)! + return result + + +@app.get("/todos", middlewares=[inject_correlation_id]) # (5)! +def get_todos(): + todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + +@logger.inject_lambda_context +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/middleware_getting_started_output.json b/examples/event_handler_rest/src/middleware_getting_started_output.json new file mode 100644 index 0000000000..aa669c5767 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_getting_started_output.json @@ -0,0 +1,13 @@ +{ + "statusCode": 200, + "body": "{\"todos\":[{\"userId\":1,\"id\":1,\"title\":\"delectus aut autem\",\"completed\":false}]}", + "isBase64Encoded": false, + "multiValueHeaders": { + "Content-Type": [ + "application/json" + ], + "x-correlation-id": [ + "ccd87d70-7a3f-4aec-b1a8-a5a558c239b2" + ] + } +} \ No newline at end of file diff --git a/examples/event_handler_rest/src/middleware_global_middlewares.py b/examples/event_handler_rest/src/middleware_global_middlewares.py new file mode 100644 index 0000000000..11da3e2797 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_global_middlewares.py @@ -0,0 +1,24 @@ +import middleware_global_middlewares_module # (1)! +import requests + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + +app = APIGatewayRestResolver() +logger = Logger() + +app.use(middlewares=[middleware_global_middlewares_module.log_request_response]) # (2)! + + +@app.get("/todos", middlewares=[middleware_global_middlewares_module.inject_correlation_id]) +def get_todos(): + todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + +@logger.inject_lambda_context +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/middleware_global_middlewares_module.py b/examples/event_handler_rest/src/middleware_global_middlewares_module.py new file mode 100644 index 0000000000..81b83c868a --- /dev/null +++ b/examples/event_handler_rest/src/middleware_global_middlewares_module.py @@ -0,0 +1,41 @@ +from aws_lambda_powertools import Logger +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + +logger = Logger() + + +def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + logger.info("Incoming request", path=app.current_event.path, request=app.current_event.raw_event) + + result = next_middleware(app) + logger.info("Response received", response=result.__dict__) + + return result + + +def inject_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + request_id = app.current_event.request_context.request_id + + # Use API Gateway REST API request ID if caller didn't include a correlation ID + correlation_id = app.current_event.headers.get("x-correlation-id", request_id) + + # Inject correlation ID in shared context and Logger + app.append_context(correlation_id=correlation_id) + logger.set_correlation_id(request_id) + + # Get response from next middleware OR /todos route + result = next_middleware(app) + + # Include Correlation ID in the response back to caller + result.headers["x-correlation-id"] = correlation_id + return result + + +def enforce_correlation_id(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + # If missing mandatory header raise an error + if not app.current_event.get_header_value("x-correlation-id", case_sensitive=False): + return Response(status_code=400, body="Correlation ID header is now mandatory.") # (1)! + + # Get the response from the next middleware and return it + return next_middleware(app) diff --git a/examples/event_handler_rest/src/split_route.py b/examples/event_handler_rest/src/split_route.py index 6c0933ea08..b9edc1d045 100644 --- a/examples/event_handler_rest/src/split_route.py +++ b/examples/event_handler_rest/src/split_route.py @@ -8,7 +8,7 @@ tracer = Tracer() logger = Logger() app = APIGatewayRestResolver() -app.include_router(split_route_module.router) +app.include_router(split_route_module.router) # (1)! # You can continue to use other utilities just as before diff --git a/examples/event_handler_rest/src/split_route_module.py b/examples/event_handler_rest/src/split_route_module.py index 0462623f90..b6a91b3fb3 100644 --- a/examples/event_handler_rest/src/split_route_module.py +++ b/examples/event_handler_rest/src/split_route_module.py @@ -25,7 +25,11 @@ def get_todos(): @router.get("/todos/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") # type: ignore[assignment] # sentinel typing # noqa: E501 + api_key: str = router.current_event.get_header_value( + name="X-Api-Key", + case_sensitive=True, + default_value="", + ) # noqa: E501 todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/examples/event_handler_rest/src/split_route_prefix_module.py b/examples/event_handler_rest/src/split_route_prefix_module.py index 41fcf8eed3..aa17e0cd34 100644 --- a/examples/event_handler_rest/src/split_route_prefix_module.py +++ b/examples/event_handler_rest/src/split_route_prefix_module.py @@ -25,7 +25,11 @@ def get_todos(): @router.get("/") @tracer.capture_method def get_todo_by_id(todo_id: str): # value come as str - api_key: str = router.current_event.get_header_value(name="X-Api-Key", case_sensitive=True, default_value="") # type: ignore[assignment] # sentinel typing # noqa: E501 + api_key: str = router.current_event.get_header_value( + name="X-Api-Key", + case_sensitive=True, + default_value="", + ) # sentinel typing # noqa: E501 todos: Response = requests.get(f"{endpoint}/{todo_id}", headers={"X-Api-Key": api_key}) todos.raise_for_status() diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index da814c9110..3f095e28e4 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -78,4 +78,4 @@ "stageVariables": null, "body": "Hello from Lambda!", "isBase64Encoded": false -} +} \ No newline at end of file diff --git a/tests/events/apiGatewayProxyOtherEvent.json b/tests/events/apiGatewayProxyOtherEvent.json new file mode 100644 index 0000000000..5b9d09844a --- /dev/null +++ b/tests/events/apiGatewayProxyOtherEvent.json @@ -0,0 +1,81 @@ +{ + "version": "1.0", + "resource": "/other/path", + "path": "/other/path", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2", + "Origin": "https://aws.amazon.com" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/other/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/other/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": false +} diff --git a/tests/events/apiGatewayProxyV2Event_GET.json b/tests/events/apiGatewayProxyV2Event_GET.json new file mode 100644 index 0000000000..f411ea655d --- /dev/null +++ b/tests/events/apiGatewayProxyV2Event_GET.json @@ -0,0 +1,68 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "GET", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/events/apiGatewayProxyV2OtherGetEvent.json b/tests/events/apiGatewayProxyV2OtherGetEvent.json new file mode 100644 index 0000000000..b9bd88f1c2 --- /dev/null +++ b/tests/events/apiGatewayProxyV2OtherGetEvent.json @@ -0,0 +1,68 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/other/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "GET", + "path": "/other/path", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/events/apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json b/tests/events/apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json new file mode 100644 index 0000000000..d2f4c404c7 --- /dev/null +++ b/tests/events/apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json @@ -0,0 +1,69 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "{\"username\": \"lessa\"}", + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/events/apiGatewayProxyV2SchemaMiddlwareValidEvent.json b/tests/events/apiGatewayProxyV2SchemaMiddlwareValidEvent.json new file mode 100644 index 0000000000..7be3d1194d --- /dev/null +++ b/tests/events/apiGatewayProxyV2SchemaMiddlwareValidEvent.json @@ -0,0 +1,69 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "192.168.0.1/32", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "{\"message\": \"hello world\", \"username\": \"lessa\"}", + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/events/apigatewayeSchemaMiddlwareInvalidEvent.json b/tests/events/apigatewayeSchemaMiddlwareInvalidEvent.json new file mode 100644 index 0000000000..13d810870e --- /dev/null +++ b/tests/events/apigatewayeSchemaMiddlwareInvalidEvent.json @@ -0,0 +1,81 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "POST", + "headers": { + "Header1": "value1", + "Header2": "value2", + "Origin": "https://aws.amazon.com" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "{\"username\": \"lessa\"}", + "isBase64Encoded": false +} diff --git a/tests/events/apigatewayeSchemaMiddlwareValidEvent.json b/tests/events/apigatewayeSchemaMiddlwareValidEvent.json new file mode 100644 index 0000000000..454465b9a4 --- /dev/null +++ b/tests/events/apigatewayeSchemaMiddlwareValidEvent.json @@ -0,0 +1,81 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "POST", + "headers": { + "Header1": "value1", + "Header2": "value2", + "Origin": "https://aws.amazon.com" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "{\"message\": \"hello world\", \"username\": \"lessa\"}", + "isBase64Encoded": false +} diff --git a/tests/functional/event_handler/conftest.py b/tests/functional/event_handler/conftest.py index 3c281ef0d5..c7a4ac6e50 100644 --- a/tests/functional/event_handler/conftest.py +++ b/tests/functional/event_handler/conftest.py @@ -7,3 +7,35 @@ def json_dump(): # our serializers reduce length to save on costs; fixture to replicate separators return lambda obj: json.dumps(obj, separators=(",", ":")) + + +@pytest.fixture +def validation_schema(): + return { + "$schema": "https://json-schema.org/draft-07/schema", + "$id": "https://example.com/example.json", + "type": "object", + "title": "Sample schema", + "description": "The root schema comprises the entire JSON document.", + "examples": [{"message": "hello world", "username": "lessa"}], + "required": ["message", "username"], + "properties": { + "message": { + "$id": "#/properties/message", + "type": "string", + "title": "The message", + "examples": ["hello world"], + }, + "username": { + "$id": "#/properties/username", + "type": "string", + "title": "The username", + "examples": ["lessa"], + }, + }, + } + + +@pytest.fixture +def raw_event(): + return {"message": "hello hello", "username": "blah blah"} diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 44781193b3..b5f196200a 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -909,7 +909,8 @@ def test_debug_print_event(capsys): # THEN print the event out, err = capsys.readouterr() assert "\n" in out - assert json.loads(out) == event + output: str = out.split("\n")[0] + assert json.loads(output) == event def test_similar_dynamic_routes(): diff --git a/tests/functional/event_handler/test_api_middlewares.py b/tests/functional/event_handler/test_api_middlewares.py new file mode 100644 index 0000000000..8f98b93343 --- /dev/null +++ b/tests/functional/event_handler/test_api_middlewares.py @@ -0,0 +1,480 @@ +from typing import List + +import pytest + +from aws_lambda_powertools.event_handler import content_types +from aws_lambda_powertools.event_handler.api_gateway import ( + APIGatewayHttpResolver, + ApiGatewayResolver, + APIGatewayRestResolver, + ProxyEventType, + Response, + Router, +) +from aws_lambda_powertools.event_handler.exceptions import BadRequestError +from aws_lambda_powertools.event_handler.middlewares import ( + BaseMiddlewareHandler, + NextMiddleware, +) +from aws_lambda_powertools.event_handler.middlewares.schema_validation import ( + SchemaValidationMiddleware, +) +from aws_lambda_powertools.event_handler.types import EventHandlerInstance +from tests.functional.utils import load_event + +API_REST_EVENT = load_event("apiGatewayProxyEvent.json") +API_RESTV2_EVENT = load_event("apiGatewayProxyV2Event_GET.json") + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_route_with_middleware(app: ApiGatewayResolver, event): + # define custom middleware to inject new argument - "custom" + def middleware_1(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add additional data to Router Context + app.append_context(custom="custom") + response = next_middleware(app) + + return response + + # define custom middleware to inject new argument - "another_one" + def middleware_2(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add additional data to Router Context + app.append_context(another_one=6) + response = next_middleware(app) + + return response + + @app.get("/my/path", middlewares=[middleware_1, middleware_2]) + def get_lambda() -> Response: + another_one = app.context.get("another_one") + custom = app.context.get("custom") + assert another_one == 6 + assert custom == "custom" + + return Response(200, content_types.TEXT_HTML, "foo") + + # WHEN calling the event handler + result = app(event, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["body"] == "foo" + + +@pytest.mark.parametrize( + "app, event, other_event", + [ + ( + ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), + API_REST_EVENT, + load_event("apiGatewayProxyOtherEvent.json"), + ), + ( + APIGatewayRestResolver(), + API_REST_EVENT, + load_event("apiGatewayProxyOtherEvent.json"), + ), + ( + APIGatewayHttpResolver(), + API_RESTV2_EVENT, + load_event("apiGatewayProxyV2OtherGetEvent.json"), + ), + ], +) +def test_with_router_middleware(app: ApiGatewayResolver, event, other_event): + # define custom middleware to inject new argument - "custom" + def global_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add custom data to context + app.append_context(custom="custom") + response = next_middleware(app) + + return response + + # define custom middleware to inject new argument - "another_one" + def middleware_2(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add data to resolver context + app.append_context(another_one=6) + response = next_middleware(app) + + return response + + app.use([global_middleware]) + + @app.get("/my/path", middlewares=[middleware_2]) + def get_lambda() -> Response: + another_one: int = app.context.get("another_one") + custom: str = app.context.get("custom") + assert another_one == 6 + assert custom == "custom" + + return Response(200, content_types.TEXT_HTML, "foo") + + # WHEN calling the event handler + result = app(event, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["body"] == "foo" + + @app.get("/other/path") + def get_other_lambda() -> Response: + custom: str = app.context.get("custom") + assert custom == "custom" + + return Response(200, content_types.TEXT_HTML, "other_foo") + + # WHEN calling the event handler + result = app(other_event, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["body"] == "other_foo" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_dynamic_route_with_middleware(app: ApiGatewayResolver, event): + def middleware_one(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # inject data into the resolver context + app.append_context(injected="injected_value") + response = next_middleware(app) + + return response + + @app.get("//", middlewares=[middleware_one]) + def get_lambda(my_id: str, name: str) -> Response: + injected: str = app.context.get("injected") + assert name == "my" + assert injected == "injected_value" + + return Response(200, content_types.TEXT_HTML, my_id) + + # WHEN calling the event handler + result = app(event, {}) + + # THEN + assert result["statusCode"] == 200 + assert result["body"] == "path" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_middleware_early_return(app: ApiGatewayResolver, event): + def middleware_one(app: ApiGatewayResolver, next_middleware): + # inject a variable into resolver context + app.append_context(injected="injected_value") + response = next_middleware(app) + + return response + + def early_return_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + assert app.context.get("injected") == "injected_value" + + return Response(400, content_types.TEXT_HTML, "bad_response") + + def not_executed_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # This should never be executed - if it is an excpetion will be raised + raise NotImplementedError() + + @app.get("//", middlewares=[middleware_one, early_return_middleware, not_executed_middleware]) + def get_lambda(my_id: str, name: str) -> Response: + assert name == "my" + assert app.context.get("injected") == "injected_value" + + return Response(200, content_types.TEXT_HTML, my_id) + + # WHEN calling the event handler + result = app(event, {}) + + # THEN + assert result["statusCode"] == 400 + assert result["body"] == "bad_response" + + +@pytest.mark.parametrize( + "app, event", + [ + ( + ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), + load_event("apigatewayeSchemaMiddlwareValidEvent.json"), + ), + ( + APIGatewayRestResolver(), + load_event("apigatewayeSchemaMiddlwareValidEvent.json"), + ), + ( + APIGatewayHttpResolver(), + load_event("apiGatewayProxyV2SchemaMiddlwareValidEvent.json"), + ), + ], +) +def test_pass_schema_validation(app: ApiGatewayResolver, event, validation_schema): + @app.post("/my/path", middlewares=[SchemaValidationMiddleware(validation_schema)]) + def post_lambda() -> Response: + return Response(200, content_types.TEXT_HTML, "path") + + # WHEN calling the event handler + result = app(event, {}) + + # THEN + assert result["statusCode"] == 200 + assert result["body"] == "path" + + +@pytest.mark.parametrize( + "app, event", + [ + ( + ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), + load_event("apigatewayeSchemaMiddlwareInvalidEvent.json"), + ), + ( + APIGatewayRestResolver(), + load_event("apigatewayeSchemaMiddlwareInvalidEvent.json"), + ), + ( + APIGatewayHttpResolver(), + load_event("apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json"), + ), + ], +) +def test_fail_schema_validation(app: ApiGatewayResolver, event, validation_schema): + @app.post("/my/path", middlewares=[SchemaValidationMiddleware(validation_schema)]) + def post_lambda() -> Response: + return Response(200, content_types.TEXT_HTML, "Should not be returned") + + # WHEN calling the event handler + result = app(event, {}) + print(f"\nRESULT:::{result}") + + # THEN + assert result["statusCode"] == 400 + assert ( + result["body"] + == "{\"statusCode\":400,\"message\":\"Bad Request: Failed schema validation. Error: data must contain ['message'] properties, Path: ['data'], Data: {'username': 'lessa'}\"}" # noqa: E501 + ) + + +@pytest.mark.parametrize( + "app, event", + [ + ( + ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), + load_event("apigatewayeSchemaMiddlwareValidEvent.json"), + ), + ( + APIGatewayRestResolver(), + load_event("apigatewayeSchemaMiddlwareValidEvent.json"), + ), + ( + APIGatewayHttpResolver(), + load_event("apiGatewayProxyV2SchemaMiddlwareInvalidEvent.json"), + ), + ], +) +def test_invalid_schema_validation(app: ApiGatewayResolver, event): + @app.post("/my/path", middlewares=[SchemaValidationMiddleware(inbound_schema="schema.json")]) + def post_lambda() -> Response: + return Response(200, content_types.TEXT_HTML, "Should not be returned") + + # WHEN calling the event handler + result = app(event, {}) + + print(f"\nRESULT:::{result}") + # THEN + assert result["statusCode"] == 500 + assert result["body"] == '{"statusCode":500,"message":"Internal Server Error"}' + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_middleware_short_circuit_via_httperrors(app: ApiGatewayResolver, event): + def middleware_one(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # inject a variable into the kwargs of the middleware chain + app.append_context(injected="injected_value") + response = next_middleware(app) + + return response + + def early_return_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # ensure "injected" context variable is passed in by middleware_one + assert app.context.get("injected") == "injected_value" + raise BadRequestError("bad_response") + + def not_executed_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # This should never be executed - if it is an excpetion will be raised + raise NotImplementedError() + + @app.get("//", middlewares=[middleware_one, early_return_middleware, not_executed_middleware]) + def get_lambda(my_id: str, name: str) -> Response: + assert name == "my" + assert app.context.get("injected") == "injected_value" + + return Response(200, content_types.TEXT_HTML, my_id) + + # WHEN calling the event handler + result = app(event, {}) + + # THEN + assert result["statusCode"] == 400 + assert result["body"] == '{"statusCode":400,"message":"bad_response"}' + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_api_gateway_middleware_order_with_include_router_last(app: EventHandlerInstance, event): + # GIVEN two global middlewares: one for App and one for Router + router = Router() + + def global_app_middleware(app: EventHandlerInstance, next_middleware: NextMiddleware): + middleware_order: List[str] = router.context.get("middleware_order", []) + middleware_order.append("app") + + app.append_context(middleware_order=middleware_order) + return next_middleware(app) + + def global_router_middleware(router: EventHandlerInstance, next_middleware: NextMiddleware): + middleware_order: List[str] = router.context.get("middleware_order", []) + middleware_order.append("router") + + router.append_context(middleware_order=middleware_order) + return next_middleware(app) + + @router.get("/my/path") + def dummy_route(): + middleware_order = app.context["middleware_order"] + + assert middleware_order[0] == "app" + assert middleware_order[1] == "router" + + return Response(status_code=200, body="works!") + + # WHEN App global middlewares are registered first + # followed by include_router + + router.use([global_router_middleware]) # mimics App importing Router + app.use([global_app_middleware]) + app.include_router(router) + + # THEN resolving a request should start processing global Router middlewares first + # due to insertion order + result = app(event, {}) + + assert result["statusCode"] == 200 + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_api_gateway_middleware_order_with_include_router_first(app: EventHandlerInstance, event): + # GIVEN two global middlewares: one for App and one for Router + router = Router() + + def global_app_middleware(app: EventHandlerInstance, next_middleware: NextMiddleware): + middleware_order: List[str] = router.context.get("middleware_order", []) + middleware_order.append("app") + + app.append_context(middleware_order=middleware_order) + return next_middleware(app) + + def global_router_middleware(router: EventHandlerInstance, next_middleware: NextMiddleware): + middleware_order: List[str] = router.context.get("middleware_order", []) + middleware_order.append("router") + + router.append_context(middleware_order=middleware_order) + return next_middleware(app) + + @router.get("/my/path") + def dummy_route(): + middleware_order = app.context["middleware_order"] + + assert middleware_order[0] == "router" + assert middleware_order[1] == "app" + + return Response(status_code=200, body="works!") + + # WHEN App include router middlewares first + # followed by App global middlewares registration + + router.use([global_router_middleware]) # mimics App importing Router + app.include_router(router) + + app.use([global_app_middleware]) + + # THEN resolving a request should start processing global Router middlewares first + # due to insertion order + result = app(event, {}) + + assert result["statusCode"] == 200 + + +def test_class_based_middleware(): + # GIVEN a class-based middleware implementing BaseMiddlewareHandler correctly + class CorrelationIdMiddleware(BaseMiddlewareHandler): + def __init__(self, header: str): + super().__init__() + self.header = header + + def handler(self, app: ApiGatewayResolver, get_response: NextMiddleware, **kwargs) -> Response: + request_id = app.current_event.request_context.request_id # type: ignore[attr-defined] # using REST event in a base Resolver # noqa: E501 + correlation_id = app.current_event.get_header_value( + name=self.header, + default_value=request_id, + ) # noqa: E501 + + response = get_response(app, **kwargs) + response.headers[self.header] = correlation_id + + return response + + resolver = ApiGatewayResolver() + event = load_event("apiGatewayProxyEvent.json") + + # WHEN instantiated with extra configuration as part of a route handler + @resolver.get("/my/path", middlewares=[CorrelationIdMiddleware(header="X-Correlation-Id")]) + def post_lambda(): + return {"hello": "world"} + + # THEN it should work as any other middleware when a request is processed + result = resolver(event, {}) + assert result["statusCode"] == 200 + assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501