diff --git a/src/hayhooks/server/pipelines/models.py b/src/hayhooks/server/pipelines/models.py index 51d37ad..3d36f4a 100644 --- a/src/hayhooks/server/pipelines/models.py +++ b/src/hayhooks/server/pipelines/models.py @@ -1,6 +1,5 @@ from pandas import DataFrame from pydantic import BaseModel, ConfigDict, create_model - from hayhooks.server.utils.create_valid_type import handle_unsupported_types @@ -31,10 +30,12 @@ def get_request_model(pipeline_name: str, pipeline_inputs): except TypeError as e: print(f"ERROR at {component_name!r}, {name}: {typedef}") raise e - component_model[name] = ( - input_type, - typedef.get("default_value", ...), - ) + + if input_type is not None: + component_model[name] = ( + input_type, + typedef.get("default_value", ...), + ) request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...) return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config) @@ -61,30 +62,34 @@ def get_response_model(pipeline_name: str, pipeline_outputs): return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config) +def convert_value_to_dict(value): + """Convert a single value to a dictionary if possible""" + if hasattr(value, "to_dict"): + if "init_parameters" in value.to_dict(): + return value.to_dict()["init_parameters"] + return value.to_dict() + elif hasattr(value, "model_dump"): + return value.model_dump() + elif isinstance(value, dict): + return {k: convert_value_to_dict(v) for k, v in value.items()} + elif isinstance(value, list): + return [convert_value_to_dict(item) for item in value] + else: + return value + + def convert_component_output(component_output): """ - Converts outputs from a component as a dict so that it can be validated against response model - - Component output has this form: + Converts component outputs to dictionaries that can be validated against response model. + Handles nested structures recursively. - "documents":[ - {"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."} - ] + Args: + component_output: Dict with component outputs + Returns: + Dict with all nested objects converted to dictionaries """ - result = {} - for output_name, data in component_output.items(): - - def get_value(data): - if hasattr(data, "to_dict") and "init_parameters" in data.to_dict(): - return data.to_dict()["init_parameters"] - elif hasattr(data, "to_dict"): - return data.to_dict() - else: - return data - - if type(data) is list: - result[output_name] = [get_value(d) for d in data] - else: - result[output_name] = get_value(data) - return result + if isinstance(component_output, dict): + return {name: convert_value_to_dict(data) for name, data in component_output.items()} + + return convert_value_to_dict(component_output) diff --git a/src/hayhooks/server/utils/create_valid_type.py b/src/hayhooks/server/utils/create_valid_type.py index 906a84b..6c02f9e 100644 --- a/src/hayhooks/server/utils/create_valid_type.py +++ b/src/hayhooks/server/utils/create_valid_type.py @@ -1,20 +1,39 @@ +from collections.abc import Callable as CallableABC from inspect import isclass from types import GenericAlias -from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints +from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints -def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]: +def is_callable_type(t): + """Check if a type is any form of callable""" + if t in (Callable, CallableABC): + return True + + # Check origin type + origin = get_origin(t) + if origin in (Callable, CallableABC): + return True + + # Handle Optional/Union types + if origin in (Union, type(Optional[int])): # type(Optional[int]) handles runtime Optional type + args = get_args(t) + return any(is_callable_type(arg) for arg in args) + + return False + + +def handle_unsupported_types( + type_: type, types_mapping: Dict[type, type], skip_callables: bool = True +) -> Union[GenericAlias, type, None]: """ Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping. - - :param type_: Type to replace if not supported - :param types_mapping: Mapping of types to replace """ - def _handle_generics(t_) -> GenericAlias: - """ - Handle generics recursively - """ + def handle_generics(t_) -> Union[GenericAlias, None]: + """Handle generics recursively""" + if is_callable_type(t_) and skip_callables: + return None + child_typing = [] for t in get_args(t_): if t in types_mapping: @@ -26,20 +45,19 @@ def _handle_generics(t_) -> GenericAlias: child_typing.append(result) if len(child_typing) == 2 and child_typing[1] is type(None): - # because TypedDict can't handle union types with None - # rewrite them as Optional[type] return Optional[child_typing[0]] else: return GenericAlias(get_origin(t_), tuple(child_typing)) + if is_callable_type(type_) and skip_callables: + return None + if isclass(type_): new_type = {} for arg_name, arg_type in get_type_hints(type_).items(): if get_args(arg_type): - new_type[arg_name] = _handle_generics(arg_type) + new_type[arg_name] = handle_generics(arg_type) else: new_type[arg_name] = arg_type - return type_ - - return _handle_generics(type_) + return handle_generics(type_) diff --git a/src/hayhooks/server/utils/deploy_utils.py b/src/hayhooks/server/utils/deploy_utils.py index 9a5d762..2d42742 100644 --- a/src/hayhooks/server/utils/deploy_utils.py +++ b/src/hayhooks/server/utils/deploy_utils.py @@ -1,15 +1,16 @@ from fastapi import HTTPException -from fastapi.responses import JSONResponse from fastapi.concurrency import run_in_threadpool +from fastapi.responses import JSONResponse from hayhooks.server.pipelines import registry from hayhooks.server.pipelines.models import ( PipelineDefinition, + convert_component_output, get_request_model, get_response_model, - convert_component_output, ) + def deploy_pipeline_def(app, pipeline_def: PipelineDefinition): try: pipe = registry.add(pipeline_def.name, pipeline_def.source_code) diff --git a/tests/test_convert_component_output.py b/tests/test_convert_component_output.py new file mode 100644 index 0000000..df5cb68 --- /dev/null +++ b/tests/test_convert_component_output.py @@ -0,0 +1,43 @@ +from hayhooks.server.pipelines.models import convert_component_output +from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails + + +def test_convert_component_output_with_nested_models(): + sample_response = [ + { + 'model': 'gpt-4o-mini-2024-07-18', + 'index': 0, + 'finish_reason': 'stop', + 'usage': { + 'completion_tokens': 52, + 'prompt_tokens': 29, + 'total_tokens': 81, + 'completion_tokens_details': CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + 'prompt_tokens_details': PromptTokensDetails(audio_tokens=0, cached_tokens=0), + }, + } + ] + + converted_output = convert_component_output(sample_response) + + assert converted_output == [ + { + 'model': 'gpt-4o-mini-2024-07-18', + 'index': 0, + 'finish_reason': 'stop', + 'usage': { + 'completion_tokens': 52, + 'prompt_tokens': 29, + 'total_tokens': 81, + 'completion_tokens_details': { + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + }, + 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}, + }, + } + ] diff --git a/tests/test_handle_callable_type.py b/tests/test_handle_callable_type.py new file mode 100644 index 0000000..529b6cf --- /dev/null +++ b/tests/test_handle_callable_type.py @@ -0,0 +1,54 @@ +from collections.abc import Callable as CallableABC +from types import NoneType +from typing import Any, Callable, Optional, Union + +import haystack +import pytest + +from hayhooks.server.pipelines.models import get_request_model +from hayhooks.server.utils.create_valid_type import is_callable_type + + +@pytest.mark.parametrize( + "t, expected", + [ + (Callable, True), + (CallableABC, True), + (Callable[[int], str], True), + (Callable[..., Any], True), + (int, False), + (str, False), + (Any, False), + (Union[int, str], False), + (Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True), + ], +) +def test_is_callable_type(t, expected): + assert is_callable_type(t) == expected + + +def test_skip_callables_when_creating_pipeline_models(): + pipeline_name = "test_pipeline" + pipeline_inputs = { + "generator": { + "system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None}, + "streaming_callback": { + "type": Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], + "is_mandatory": False, + "default_value": None, + }, + "generation_kwargs": { + "type": Optional[dict[str, Any]], + "is_mandatory": False, + "default_value": None, + }, + } + } + + request_model = get_request_model(pipeline_name, pipeline_inputs) + + # This line used to throw an error because the Callable type was not handled correctly + # by the handle_unsupported_types function + assert request_model.model_json_schema() is not None + assert request_model.__name__ == "Test_pipelineRunRequest" + assert "streaming_callback" not in request_model.model_json_schema()["$defs"]["ComponentParams"]["properties"]