From 9f593742d9e4a38aa051444fd1f7263ad98886fe Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 22 Nov 2023 21:39:17 +0100 Subject: [PATCH] fix(event_handler): fix format for OpenAPI path templating (#3399) --- aws_lambda_powertools/event_handler/api_gateway.py | 13 +++++++++---- .../event_handler/openapi/dependant.py | 4 ++-- .../functional/event_handler/test_openapi_params.py | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 05831a2eea..4263a5132a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -316,6 +316,11 @@ def __init__( """ self.method = method.upper() self.path = "/" if path.strip() == "" else path + + # OpenAPI spec only understands paths with { }. So we'll have to convert Powertools' < >. + # https://swagger.io/specification/#path-templating + self.openapi_path = re.sub(r"<(.*?)>", lambda m: f"{{{''.join(m.group(1))}}}", self.path) + self.rule = rule self.func = func self._middleware_stack = func @@ -435,7 +440,7 @@ def dependant(self) -> "Dependant": if self._dependant is None: from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant - self._dependant = get_dependant(path=self.path, call=self.func) + self._dependant = get_dependant(path=self.openapi_path, call=self.func) return self._dependant @@ -542,7 +547,7 @@ def _openapi_operation_summary(self) -> str: Returns the OpenAPI operation summary. If the user has not provided a summary, we generate one based on the route path and method. """ - return self.summary or f"{self.method.upper()} {self.path}" + return self.summary or f"{self.method.upper()} {self.openapi_path}" def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: """ @@ -692,7 +697,7 @@ def _openapi_operation_return( return {"schema": return_schema} def _generate_operation_id(self) -> str: - operation_id = self.func.__name__ + self.path + operation_id = self.func.__name__ + self.openapi_path operation_id = re.sub(r"\W", "_", operation_id) operation_id = operation_id + "_" + self.method.lower() return operation_id @@ -1452,7 +1457,7 @@ def get_openapi_schema( if result: path, path_definitions = result if path: - paths.setdefault(route.path, {}).update(path) + paths.setdefault(route.openapi_path, {}).update(path) if path_definitions: definitions.update(path_definitions) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 87e0c7dfb3..e22eb535a7 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -124,7 +124,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_path_param_names(path: str) -> Set[str]: """ - Returns the path parameter names from a path template. Those are the strings between < and >. + Returns the path parameter names from a path template. Those are the strings between { and }. Parameters ---------- @@ -137,7 +137,7 @@ def get_path_param_names(path: str) -> Set[str]: The path parameter names """ - return set(re.findall("<(.*?)>", path)) + return set(re.findall("{(.*?)}", path)) def get_dependant( diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 6e4f0395af..9209cb9dec 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -70,13 +70,13 @@ def handler(user_id: str, include_extra: bool = False): assert schema.info.version == "0.2.2" assert len(schema.paths.keys()) == 1 - assert "/users/" in schema.paths + assert "/users/{user_id}" in schema.paths - path = schema.paths["/users/"] + path = schema.paths["/users/{user_id}"] assert path.get get = path.get - assert get.summary == "GET /users/" + assert get.summary == "GET /users/{user_id}" assert get.operationId == "handler_users__user_id__get" assert len(get.parameters) == 2