diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1383b74ada0..e36fa7e8740 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -9,6 +9,7 @@ from enum import Enum from functools import partial from http import HTTPStatus +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -236,6 +237,15 @@ def __init__( if content_type: self.headers.setdefault("Content-Type", content_type) + def is_json(self) -> bool: + """ + Returns True if the response is JSON, based on the Content-Type. + """ + content_type = self.headers.get("Content-Type", "") + if isinstance(content_type, list): + content_type = content_type[0] + return content_type.startswith("application/json") + class Route: """Internally used Route Configuration""" @@ -255,6 +265,7 @@ def __init__( response_description: Optional[str], tags: Optional[List["Tag"]], operation_id: Optional[str], + include_in_schema: bool, middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -288,6 +299,8 @@ def __init__( The list of OpenAPI tags to be used for this route operation_id: Optional[str] The OpenAPI operationId for this route + include_in_schema: bool + Whether or not to include this route in the OpenAPI schema middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -304,6 +317,7 @@ def __init__( self.responses = responses self.response_description = response_description self.tags = tags or [] + self.include_in_schema = include_in_schema self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -483,7 +497,6 @@ def _get_openapi_path( # Add the response schema to the OpenAPI 200 response json_response.update( self._openapi_operation_return( - operation_id=self.operation_id, param=dependant.return_param, model_name_map=model_name_map, field_mapping=field_mapping, @@ -643,7 +656,6 @@ def _openapi_operation_parameters( @staticmethod def _openapi_operation_return( *, - operation_id: str, param: Optional["ModelField"], model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[ @@ -667,7 +679,7 @@ def _openapi_operation_return( field_mapping=field_mapping, ) - return {"name": f"Return {operation_id}", "schema": return_schema} + return {"schema": return_schema} def _generate_operation_id(self) -> str: operation_id = self.func.__name__ + self.path @@ -792,6 +804,7 @@ def route( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -849,6 +862,7 @@ def get( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -885,6 +899,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -900,6 +915,7 @@ def post( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -937,6 +953,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -952,6 +969,7 @@ def put( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -989,6 +1007,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1004,6 +1023,7 @@ def delete( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1040,6 +1060,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1055,6 +1076,7 @@ def patch( response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List["Tag"]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1094,6 +1116,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1345,7 +1368,7 @@ def get_openapi_schema( A URL to the Terms of Service for the API. MUST be in the format of a URL. contact: Contact, optional The contact information for the exposed API. - license_info: + license_info: License, optional The license information for the exposed API. Returns @@ -1403,6 +1426,9 @@ def get_openapi_schema( # Add routes to the OpenAPI schema for route in all_routes: + if not route.include_in_schema: + continue + result = route._get_openapi_path( dependant=route.dependant, operation_ids=operation_ids, @@ -1464,7 +1490,7 @@ def get_openapi_json_schema( A URL to the Terms of Service for the API. MUST be in the format of a URL. contact: Contact, optional The contact information for the exposed API. - license_info: + license_info: License, optional The license information for the exposed API. Returns @@ -1492,6 +1518,152 @@ def get_openapi_json_schema( indent=2, ) + def enable_swagger( + self, + *, + path: str = "/swagger", + title: str = "Powertools API", + version: str = "1.0.0", + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List["Tag"]] = None, + servers: Optional[List["Server"]] = None, + terms_of_service: Optional[str] = None, + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, + swagger_base_url: Optional[str] = None, + middlewares: Optional[List[Callable[..., Response]]] = None, + ): + """ + Returns the OpenAPI schema as a JSON serializable dict + + Parameters + ---------- + path: str, default = "/swagger" + The path to the swagger UI. + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[Tag], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: License, optional + The license information for the exposed API. + swagger_base_url: str, optional + The base url for the swagger UI. If not provided, we will serve a recent version of the Swagger UI. + middlewares: List[Callable[..., Response]], optional + List of middlewares to be used for the swagger route. + """ + from aws_lambda_powertools.event_handler.openapi.models import Server + + if not swagger_base_url: + + @self.get("/swagger.js", include_in_schema=False) + def swagger_js(): + body = Path.open(Path(__file__).parent / "openapi" / "swagger_ui" / "swagger-ui-bundle.min.js").read() + return Response( + status_code=200, + content_type="text/javascript", + body=body, + compress=True, + ) + + @self.get("/swagger.css", include_in_schema=False) + def swagger_css(): + body = Path.open(Path(__file__).parent / "openapi" / "swagger_ui" / "swagger-ui.min.css").read() + return Response( + status_code=200, + content_type="text/css", + body=body, + compress=True, + ) + + @self.get(path, middlewares=middlewares, include_in_schema=False) + def swagger_handler(): + base_path = self._get_base_path() + + if swagger_base_url: + swagger_js = f"{swagger_base_url}/swagger-ui-bundle.js" + swagger_css = f"{swagger_base_url}/swagger-ui.min.css" + else: + swagger_js = f"{base_path}/swagger.js" + swagger_css = f"{base_path}/swagger.css" + + openapi_servers = servers or [Server(url=base_path)] + + spec = self.get_openapi_json_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=openapi_servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + ) + + body = f""" + + +
+ +