From aa68e4087a3df8d20e2ebe72e9405e5d2f1d39ea Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Thu, 7 Sep 2023 21:58:52 +0200 Subject: [PATCH] chore: last cleanups Signed-off-by: heitorlessa --- .../event_handler/api_gateway.py | 78 ++++++++++-------- .../event_handler/middlewares/base.py | 80 +++++++++---------- .../src/all_routes_middleware.py | 31 ------- .../src/custom_middlewares.py | 31 ------- .../src/route_middleware.py | 29 ------- 5 files changed, 83 insertions(+), 166 deletions(-) delete mode 100644 examples/event_handler_rest/src/all_routes_middleware.py delete mode 100644 examples/event_handler_rest/src/custom_middlewares.py delete mode 100644 examples/event_handler_rest/src/route_middleware.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 781bc3ef5b0..2163d7d762e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -206,7 +206,7 @@ def __init__( cors: bool, compress: bool, cache_control: Optional[str], - middlewares: Optional[List[Callable[..., Any]]], + middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -225,9 +225,8 @@ def __init__( Whether or not to enable gzip compression for this route cache_control: Optional[str] The cache control header value, example "max-age=3600" - middlewares: Optional[List[Callable[..., Any]]] - The list of route middlewares. These are called in the order they are - provided. + middlewares: Optional[List[Callable[..., Response]]] + The list of route middlewares to be called in order. """ self.method = method.upper() self.rule = rule @@ -238,9 +237,7 @@ def __init__( self.cache_control = cache_control self.middlewares = middlewares or [] - """ - _middleware_stack_built is used to ensure the middleware stack is only built once. - """ + # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False def __call__( @@ -258,7 +255,7 @@ def __call__( ---------- router_middlewares: List[Callable] The list of Router Middlewares (assigned to ALL routes) - app: Callable + app: "ApiGatewayResolver" The ApiGatewayResolver instance to pass into the middleware stack route_arguments: Dict[str, str] The route arguments to pass to the app function (extracted from the Api Gateway @@ -267,13 +264,11 @@ def __call__( Returns ------- Union[Dict, Tuple, Response] - Returns an API Response object in ALL cases, excepting when the original API route - handler is called which may also return a Dict or Tuple response. + API Response object in ALL cases, except when the original API route + handler is called which may also return a Dict, Tuple, or Response. """ - # Check self._middleware_stack_built to ensure the middleware stack is only built once. - # This will save CPU time when an API route is processed multiple times. - # + # Save CPU cycles by building middleware stack once if not self._middleware_stack_built: self._build_middleware_stack(router_middlewares=router_middlewares) @@ -312,14 +307,16 @@ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) middleware handlers is applied in the order of being added to the handler. """ all_middlewares = router_middlewares + self.middlewares + logger.debug(f"Building middleware stack: {all_middlewares}") # IMPORTANT: # this must be the last middleware in the stack (tech debt for backward - # compatability purposes) + # compatibility purposes) # - # This adapter will call the registered API passing only the expected route arguments extracted from the path + # This adapter will: + # 1. Call the registered API passing only the expected route arguments extracted from the path # and not the middleware. - # This adapter will adapt the response type of the route handler (Union[Dict, Tuple, Response]) + # 2. Adapt the response type of the route handler (Union[Dict, Tuple, Response]) # and normalise into a Response object so middleware will always have a constant signature all_middlewares.append(_registered_api_adapter) @@ -450,31 +447,42 @@ def route( def use(self, middlewares: List[Callable[..., Response]]) -> None: """ - Add a list of middlewares to the global router middleware list + Add one or more global middlewares that run before/after route specific middleware. - These middlewares will be called in insertion order and - before any middleware registered at the route level. + NOTE: Middlewares are called in insertion order. + + Parameters + ---------- + middlewares: List[Callable[..., Response]] + List of global middlewares to be used + + Examples + -------- - Example - ------- Add middlewares to be used for every request processed by the Router. - ``` - my_custom_middleware = new CustomMiddleware() + ```python + from aws_lambda_powertools import Logger + from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + from aws_lambda_powertools.event_handler.middlewares import NextMiddleware - app.use([my_custom_middleware]) + logger = Logger() + app = APIGatewayRestResolver() - ``` + def log_request_response(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + logger.info("Incoming request", path=app.current_event.path, request=app.current_event.raw_event) - Parameters - ---------- - middlewares - List of middlewares to be used + result = next_middleware(app) + logger.info("Response received", response=result.__dict__) - Returns - ------- - None + return result + app.use(middlewares=[log_request_response]) + + def lambda_handler(event, context): + return app.resolve(event, context) + ``` """ self._router_middlewares = self._router_middlewares + middlewares @@ -646,14 +654,14 @@ def lambda_handler(event, context): def _push_processed_stack_frame(self, frame: str): """ Add Current Middleware to the Middleware Stack Frames - The stack frames will be used when excpetions are thrown and Powertools + The stack frames will be used when exceptions are thrown and Powertools debug is enabled by developers. """ self.processed_stack_frames.append(frame) def _reset_processed_stack(self): """Reset the Processed Stack Frames""" - self.processed_stack_frames = [] + self.processed_stack_frames.clear() def append_context(self, **additional_context): """Append key=value data as routing context""" @@ -832,6 +840,7 @@ def __init__( self._debug = self._has_debug(debug) self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution + self.processed_stack_frames = [] # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1005,8 +1014,7 @@ def _resolve(self) -> ResponseBuilder: if match_results: logger.debug("Found a registered route. Calling function") # Add matched Route reference into the Resolver context - self.append_context(_route=route) - self.append_context(_path=path) + self.append_context(_route=route, _path=path) return self._call_route(route, match_results.groupdict()) # pass fn args diff --git a/aws_lambda_powertools/event_handler/middlewares/base.py b/aws_lambda_powertools/event_handler/middlewares/base.py index ce22f736520..32a4486bb31 100644 --- a/aws_lambda_powertools/event_handler/middlewares/base.py +++ b/aws_lambda_powertools/event_handler/middlewares/base.py @@ -18,64 +18,64 @@ def __name__(self) -> str: # noqa A003 class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC): - """ - Base class for Middleware Handlers + """Base implementation for Middlewares to run code before and after in a chain. + This is the middleware handler function where middleware logic is implemented. - Here you have the option to execute code before and after the next handler in the - middleware chain is called. The next middleware handler is represented by `next_middleware`. + The next middleware handler is represented by `next_middleware`, returning a Response object. + + Examples + -------- + **Correlation ID Middleware** ```python + import requests - # Place code here for actions BEFORE the next middleware handler is called - # or optionally raise an exception to short-circuit the middleware execution chain + from aws_lambda_powertools import Logger + from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response + from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware - # Get the response from the NEXT middleware handler (optionally injecting custom - # arguments into the next_middleware call) - result: Response = next_middleware(app, my_custom_arg="handled") + app = APIGatewayRestResolver() + logger = Logger() - # Place code here for actions AFTER the next middleware handler is called - return result - ``` + class CorrelationIdMiddleware(BaseMiddlewareHandler): + def __init__(self, header: str): + super().__init__() + self.header = header - To implement ERROR style middleware wrap the call to `next_middleware` in a `try..except` - block - you can also catch specific types of errors this way so your middleware only handles - specific types of exceptions. + def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + # BEFORE logic + request_id = app.current_event.request_context.request_id + correlation_id = app.current_event.get_header_value( + name=self.header, + default_value=request_id, + ) - for example: + # Call next middleware or route handler ('/todos') + response = next_middleware(app) - ```python + # AFTER logic + response.headers[self.header] = correlation_id - try: - result: Response = next_middleware(app, my_custom_arg="handled") - except MyCustomValidationException as e: - # Make sure we send back a 400 response for any Custom Validation Exceptions. - result.status_code = 400 - result.body = {"message": "Failed validation"} - logger.exception(f"Failed validation when handling route: {app.current_event.path}") + return response - return result - ``` - To short-circuit the middleware execution chain you can either raise an exception to cause - the function call stack to unwind naturally OR you can simple not call the `next_middleware` - handler to get the response from the next middleware handler in the chain. + @app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")]) + def get_todos(): + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() - for example: - If you wanted to ensure API callers cannot call a DELETE verb on your API (regardless of defined routes) - you could do so with the following middleware implementation. + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} - ```python - # If invalid http_method is used - return a 405 error - # and return early to short-circuit the middleware execution chain - if app.current_event.http_method == "DELETE": - return Response(status_code=405, body={"message": "DELETE verb not allowed"}) + @logger.inject_lambda_context + def lambda_handler(event, context): + return app.resolve(event, context) - # Call the next middleware in the chain (needed for when condition above is valid) - return next_middleware(app) + ``` """ @@ -111,7 +111,7 @@ def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) - ---------- app: ApiGatewayResolver An instance of an Event Handler that implements ApiGatewayResolver - next_middleware: Callable[..., Any] + next_middleware: NextMiddleware The next middleware handler in the chain Returns diff --git a/examples/event_handler_rest/src/all_routes_middleware.py b/examples/event_handler_rest/src/all_routes_middleware.py deleted file mode 100644 index b1e3a6a7e5a..00000000000 --- a/examples/event_handler_rest/src/all_routes_middleware.py +++ /dev/null @@ -1,31 +0,0 @@ -import requests -from custom_middlewares import sanitise_exceptions, validate_correlation_id -from requests import Response - -from aws_lambda_powertools import Logger, Tracer -from aws_lambda_powertools.event_handler import APIGatewayRestResolver -from aws_lambda_powertools.logging import correlation_paths -from aws_lambda_powertools.utilities.typing import LambdaContext - -tracer = Tracer() -logger = Logger() -app = APIGatewayRestResolver() - -app.use(middlewares=[validate_correlation_id, sanitise_exceptions]) - - -@app.get("/todos") -@tracer.capture_method -def get_todos(): - todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") - todos.raise_for_status() - - # for brevity, we'll limit to the first 10 only - return {"todos": todos.json()[:10]} - - -# You can continue to use other utilities just as before -@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) -@tracer.capture_lambda_handler -def lambda_handler(event: dict, context: LambdaContext) -> dict: - return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/custom_middlewares.py b/examples/event_handler_rest/src/custom_middlewares.py deleted file mode 100644 index 7f6e22f16a9..00000000000 --- a/examples/event_handler_rest/src/custom_middlewares.py +++ /dev/null @@ -1,31 +0,0 @@ -from aws_lambda_powertools import Logger -from aws_lambda_powertools.event_handler import ApiGatewayResolver, Response -from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError, ServiceError -from aws_lambda_powertools.event_handler.middlewares import NextMiddleware - -logger = Logger() - - -def validate_correlation_id(app: ApiGatewayResolver, next_middleware: NextMiddleware) -> Response: - # If missing mandatory header raise an error - if not app.current_event.headers.get("x-correlation-id", None): - raise BadRequestError("No [x-correlation-id] header provided. All requests must include this header.") - - # Get the response from the next middleware and return it - return next_middleware(app) - - -def sanitise_exceptions(app: ApiGatewayResolver, next_middleware: NextMiddleware) -> Response: - try: - # Get the Result from the next middleware - result = next_middleware(app) - except Exception as err: - logger.exception(err) - # Raise a clean error for ALL unexpected exceptions (ServiceError based Exceptions are okay) - if not isinstance(err, ServiceError): - raise InternalServerError("An error occurred during processing, please contact your administrator") from err - - raise err - - # return the result when there are no exceptions - return result diff --git a/examples/event_handler_rest/src/route_middleware.py b/examples/event_handler_rest/src/route_middleware.py deleted file mode 100644 index a135971c164..00000000000 --- a/examples/event_handler_rest/src/route_middleware.py +++ /dev/null @@ -1,29 +0,0 @@ -import requests -from custom_middlewares import sanitise_exceptions, validate_correlation_id -from requests import Response - -from aws_lambda_powertools import Logger, Tracer -from aws_lambda_powertools.event_handler import APIGatewayRestResolver -from aws_lambda_powertools.logging import correlation_paths -from aws_lambda_powertools.utilities.typing import LambdaContext - -tracer = Tracer() -logger = Logger() -app = APIGatewayRestResolver() - - -@app.get("/todos", middlewares=[validate_correlation_id, sanitise_exceptions]) -@tracer.capture_method -def get_todos(): - todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") - todos.raise_for_status() - - # for brevity, we'll limit to the first 10 only - return {"todos": todos.json()[:10]} - - -# You can continue to use other utilities just as before -@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) -@tracer.capture_lambda_handler -def lambda_handler(event: dict, context: LambdaContext) -> dict: - return app.resolve(event, context)