From 060e11068cf647bc411a2b0273511083012d88dd Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 23 Jul 2024 15:15:27 +0300 Subject: [PATCH] rewrite based on PR comments --- src/hayhooks/server/utils/create_valid_type.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 0089f3b..54d5b2b 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -26,7 +26,13 @@ def _handle_generics(t_) -> GenericAlias: else: result = t child_typing.append(result) - return GenericAlias(get_origin(t_), tuple(child_typing)) + + if len(child_typing) == 2 and child_typing[1] is type(None): + # because TypedDict can't handle union types with None + # rewrite them as Optional[type] + return Optional[child_typing[0]] + else: + return GenericAlias(get_origin(t_), tuple(child_typing)) if isclass(type_): new_type = {} @@ -35,14 +41,6 @@ def _handle_generics(t_) -> GenericAlias: new_type[arg_name] = _handle_generics(arg_type) else: new_type[arg_name] = arg_type - if new_type: - # because TypedDict can't handle union types with None - # rewrite them as Optional[type] - for arg_name, arg_type in new_type.items(): - type_args = get_args(arg_type) - if len(type_args) == 2 and type_args[1] is type(None): - new_type[arg_name] = Optional[type_args[0]] - return TypedDict(type_.__name__, new_type) return type_