diff --git a/aws_lambda_powertools/utilities/parser/functions.py b/aws_lambda_powertools/utilities/parser/functions.py index b9a35176a1e..351e214da93 100644 --- a/aws_lambda_powertools/utilities/parser/functions.py +++ b/aws_lambda_powertools/utilities/parser/functions.py @@ -35,12 +35,17 @@ def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter: The TypeAdapter instance for the given model, either retrieved from the cache or newly created and stored in the cache. """ + id_model = id(model) if id_model in CACHE_TYPE_ADAPTER: return CACHE_TYPE_ADAPTER[id_model] - CACHE_TYPE_ADAPTER[id_model] = TypeAdapter(model) + if isinstance(model, TypeAdapter): + CACHE_TYPE_ADAPTER[id_model] = model + else: + CACHE_TYPE_ADAPTER[id_model] = TypeAdapter(model) + return CACHE_TYPE_ADAPTER[id_model] diff --git a/tests/functional/parser/test_parser.py b/tests/functional/parser/test_parser.py index c7c90b70265..a6857af520d 100644 --- a/tests/functional/parser/test_parser.py +++ b/tests/functional/parser/test_parser.py @@ -154,7 +154,41 @@ class FailedCallback(pydantic.BaseModel): DogCallback = Annotated[Union[SuccessfulCallback, FailedCallback], pydantic.Field(discriminator="status")] @event_parser(model=DogCallback) - def handler(event: test_input, _: Any) -> str: + def handler(event, _: Any) -> str: + if isinstance(event, FailedCallback): + return f"Uh oh. Had a problem: {event.error}" + + return f"Successfully retrieved {event.breed} named {event.name}" + + ret = handler(test_input, None) + assert ret == expected + + +@pytest.mark.parametrize( + "test_input,expected", + [ + ( + {"status": "succeeded", "name": "Clifford", "breed": "Labrador"}, + "Successfully retrieved Labrador named Clifford", + ), + ({"status": "failed", "error": "oh some error"}, "Uh oh. Had a problem: oh some error"), + ], +) +def test_parser_unions_with_type_adapter_instance(test_input, expected): + class SuccessfulCallback(pydantic.BaseModel): + status: Literal["succeeded"] + name: str + breed: Literal["Newfoundland", "Labrador"] + + class FailedCallback(pydantic.BaseModel): + status: Literal["failed"] + error: str + + DogCallback = Annotated[Union[SuccessfulCallback, FailedCallback], pydantic.Field(discriminator="status")] + DogCallbackTypeAdapter = pydantic.TypeAdapter(DogCallback) + + @event_parser(model=DogCallbackTypeAdapter) + def handler(event, _: Any) -> str: if isinstance(event, FailedCallback): return f"Uh oh. Had a problem: {event.error}"