From 3a23052e9bdc467656ce08462ce22bfd7277b88b Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Fri, 28 Jun 2024 12:52:29 +0300 Subject: [PATCH 1/3] convert union type with none to optional --- src/hayhooks/server/pipelines/models.py | 11 +++++++++-- src/hayhooks/server/utils/create_valid_type.py | 8 +++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/hayhooks/server/pipelines/models.py b/src/hayhooks/server/pipelines/models.py index f06f9a8..35de7f7 100644 --- a/src/hayhooks/server/pipelines/models.py +++ b/src/hayhooks/server/pipelines/models.py @@ -26,7 +26,11 @@ def get_request_model(pipeline_name: str, pipeline_inputs): for component_name, inputs in pipeline_inputs.items(): component_model = {} for name, typedef in inputs.items(): - input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict}) + try: + input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict}) + except TypeError as e: + print(f"ERROR at {component_name!r}, {name}: {typedef}") + raise e component_model[name] = ( input_type, typedef.get("default_value", ...), @@ -70,7 +74,10 @@ def convert_component_output(component_output): """ result = {} for output_name, data in component_output.items(): - get_value = lambda data: data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data + + def get_value(data): + return data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data + if type(data) is list: result[output_name] = [get_value(d) for d in data] else: diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 95307f7..0089f3b 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,6 +1,6 @@ from inspect import isclass from types import GenericAlias -from typing import Dict, Union, get_args, get_origin, get_type_hints +from typing import Dict, Union, Optional, get_args, get_origin, get_type_hints from typing_extensions import TypedDict @@ -36,6 +36,12 @@ def _handle_generics(t_) -> GenericAlias: else: new_type[arg_name] = arg_type if new_type: + # because TypedDict can't handle union types with None + # rewrite them as Optional[type] + for arg_name, arg_type in new_type.items(): + type_args = get_args(arg_type) + if len(type_args) == 2 and type_args[1] is type(None): + new_type[arg_name] = Optional[type_args[0]] return TypedDict(type_.__name__, new_type) return type_ From 060e11068cf647bc411a2b0273511083012d88dd Mon Sep 17 00:00:00 2001 From: Rustem Galiullin Date: Tue, 23 Jul 2024 15:15:27 +0300 Subject: [PATCH 2/3] rewrite based on PR comments --- src/hayhooks/server/utils/create_valid_type.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 0089f3b..54d5b2b 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -26,7 +26,13 @@ def _handle_generics(t_) -> GenericAlias: else: result = t child_typing.append(result) - return GenericAlias(get_origin(t_), tuple(child_typing)) + + if len(child_typing) == 2 and child_typing[1] is type(None): + # because TypedDict can't handle union types with None + # rewrite them as Optional[type] + return Optional[child_typing[0]] + else: + return GenericAlias(get_origin(t_), tuple(child_typing)) if isclass(type_): new_type = {} @@ -35,14 +41,6 @@ def _handle_generics(t_) -> GenericAlias: new_type[arg_name] = _handle_generics(arg_type) else: new_type[arg_name] = arg_type - if new_type: - # because TypedDict can't handle union types with None - # rewrite them as Optional[type] - for arg_name, arg_type in new_type.items(): - type_args = get_args(arg_type) - if len(type_args) == 2 and type_args[1] is type(None): - new_type[arg_name] = Optional[type_args[0]] - return TypedDict(type_.__name__, new_type) return type_ From a7f07bf3cff62e42718b62d7cc21c82722001184 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Wed, 31 Jul 2024 13:12:25 +0200 Subject: [PATCH 3/3] Remove unused import --- src/hayhooks/server/utils/create_valid_type.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 54d5b2b..906a84b 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,8 +1,6 @@ from inspect import isclass from types import GenericAlias -from typing import Dict, Union, Optional, get_args, get_origin, get_type_hints - -from typing_extensions import TypedDict +from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: