Skip to content

Commit

Permalink
Merge pull request #41 from mpangrazzi/main
Browse files Browse the repository at this point in the history
Handle Callable-like types in get_request_model (e.g. streaming_callback) + Fix component output serialization
  • Loading branch information
mpangrazzi authored Nov 19, 2024
2 parents 35837d8 + d1adc2b commit 29c25d8
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 44 deletions.
59 changes: 32 additions & 27 deletions src/hayhooks/server/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, create_model

from hayhooks.server.utils.create_valid_type import handle_unsupported_types


Expand Down Expand Up @@ -31,10 +30,12 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
except TypeError as e:
print(f"ERROR at {component_name!r}, {name}: {typedef}")
raise e
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)

if input_type is not None:
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)
request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)
Expand All @@ -61,30 +62,34 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)


def convert_value_to_dict(value):
    """
    Convert a single value to a plain, dictionary-based structure if possible.

    Conversion rules, checked in this order:

    - objects exposing ``to_dict()``: the ``"init_parameters"`` entry of the returned
      mapping when that key is present, otherwise the whole ``to_dict()`` result
    - objects exposing ``model_dump()`` (e.g. Pydantic models): the ``model_dump()`` result
    - dicts: same keys, with every value converted recursively
    - lists: every item converted recursively
    - anything else is returned unchanged

    Args:
        value: Value to convert

    Returns:
        The converted value
    """
    if hasattr(value, "to_dict"):
        # Serialize only once: to_dict() may be costly, and calling it again
        # for the lookup would just repeat the same work.
        as_dict = value.to_dict()
        if "init_parameters" in as_dict:
            return as_dict["init_parameters"]
        return as_dict

    if hasattr(value, "model_dump"):
        return value.model_dump()

    if isinstance(value, dict):
        return {k: convert_value_to_dict(v) for k, v in value.items()}

    if isinstance(value, list):
        return [convert_value_to_dict(item) for item in value]

    return value


def convert_component_output(component_output):
    """
    Converts component outputs to dictionaries that can be validated against response model.
    Handles nested structures recursively (each value goes through ``convert_value_to_dict``).

    Example of a converted output:

        "documents":[
            {"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."}
        ]

    Args:
        component_output: Dict with component outputs. Any other value (e.g. a list of
            results) is accepted too and is converted as a single value.

    Returns:
        Dict with all nested objects converted to dictionaries, or the converted value
        itself when ``component_output`` is not a dict
    """
    if isinstance(component_output, dict):
        return {name: convert_value_to_dict(data) for name, data in component_output.items()}

    # Non-dict outputs have no ``.items()``; convert them as a single value instead of failing.
    return convert_value_to_dict(component_output)
48 changes: 33 additions & 15 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
from collections.abc import Callable as CallableABC
from inspect import isclass
from types import GenericAlias
from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints
from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints


def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
def is_callable_type(t):
    """
    Check if a type is any form of callable.

    Matches the bare ``Callable`` (from ``typing`` or ``collections.abc``), parametrized callables
    such as ``Callable[[int], str]``, and unions with at least one callable member, e.g.
    ``Optional[Callable[..., Any]]`` or, on Python 3.10+, ``collections.abc.Callable[..., Any] | None``.
    Only unions are searched recursively; other generics (e.g. ``List[Callable]``) are not inspected.

    :param t: Type annotation to inspect
    :return: True if ``t`` is a callable type or a union containing one, False otherwise
    """
    # Local import: only the ``types`` module object is needed here, to look up ``UnionType``.
    import types

    callable_types = (Callable, CallableABC)
    if t in callable_types:
        return True

    # Parametrized callables report ``collections.abc.Callable`` as their origin
    origin = get_origin(t)
    if origin in callable_types:
        return True

    # ``get_origin`` yields ``typing.Union`` for Optional/Union and ``types.UnionType`` for
    # PEP 604 unions (``X | Y``); the latter only exists on Python 3.10+.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if origin in union_origins:
        return any(is_callable_type(arg) for arg in get_args(t))

    return False


def handle_unsupported_types(
type_: type, types_mapping: Dict[type, type], skip_callables: bool = True
) -> Union[GenericAlias, type, None]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
:param type_: Type to replace if not supported
:param types_mapping: Mapping of types to replace
"""

def _handle_generics(t_) -> GenericAlias:
"""
Handle generics recursively
"""
def handle_generics(t_) -> Union[GenericAlias, None]:
"""Handle generics recursively"""
if is_callable_type(t_) and skip_callables:
return None

child_typing = []
for t in get_args(t_):
if t in types_mapping:
Expand All @@ -26,20 +45,19 @@ def _handle_generics(t_) -> GenericAlias:
child_typing.append(result)

if len(child_typing) == 2 and child_typing[1] is type(None):
# because TypedDict can't handle union types with None
# rewrite them as Optional[type]
return Optional[child_typing[0]]
else:
return GenericAlias(get_origin(t_), tuple(child_typing))

if is_callable_type(type_) and skip_callables:
return None

if isclass(type_):
new_type = {}
for arg_name, arg_type in get_type_hints(type_).items():
if get_args(arg_type):
new_type[arg_name] = _handle_generics(arg_type)
new_type[arg_name] = handle_generics(arg_type)
else:
new_type[arg_name] = arg_type

return type_

return _handle_generics(type_)
return handle_generics(type_)
5 changes: 3 additions & 2 deletions src/hayhooks/server/utils/deploy_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from hayhooks.server.pipelines import registry
from hayhooks.server.pipelines.models import (
PipelineDefinition,
convert_component_output,
get_request_model,
get_response_model,
convert_component_output,
)


def deploy_pipeline_def(app, pipeline_def: PipelineDefinition):
try:
pipe = registry.add(pipeline_def.name, pipeline_def.source_code)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_convert_component_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from hayhooks.server.pipelines.models import convert_component_output
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails


def test_convert_component_output_with_nested_models():
    """
    A list output whose items hold ``CompletionTokensDetails`` / ``PromptTokensDetails`` objects
    inside a nested ``usage`` dict must come back with those objects turned into plain dicts.
    """
    # Fields that are already plain values and must pass through untouched
    plain_meta = {
        'model': 'gpt-4o-mini-2024-07-18',
        'index': 0,
        'finish_reason': 'stop',
    }
    plain_usage = {
        'completion_tokens': 52,
        'prompt_tokens': 29,
        'total_tokens': 81,
    }

    completion_details = CompletionTokensDetails(
        accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
    )
    prompt_details = PromptTokensDetails(audio_tokens=0, cached_tokens=0)

    sample_response = [
        {
            **plain_meta,
            'usage': {
                **plain_usage,
                'completion_tokens_details': completion_details,
                'prompt_tokens_details': prompt_details,
            },
        }
    ]

    expected_output = [
        {
            **plain_meta,
            'usage': {
                **plain_usage,
                'completion_tokens_details': {
                    'accepted_prediction_tokens': 0,
                    'audio_tokens': 0,
                    'reasoning_tokens': 0,
                    'rejected_prediction_tokens': 0,
                },
                'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0},
            },
        }
    ]

    assert convert_component_output(sample_response) == expected_output
54 changes: 54 additions & 0 deletions tests/test_handle_callable_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from collections.abc import Callable as CallableABC
from types import NoneType
from typing import Any, Callable, Optional, Union

import haystack
import pytest

from hayhooks.server.pipelines.models import get_request_model
from hayhooks.server.utils.create_valid_type import is_callable_type


@pytest.mark.parametrize(
    "t, expected",
    [
        # Bare and parametrized callables
        (Callable, True),
        (CallableABC, True),
        (Callable[[int], str], True),
        (Callable[..., Any], True),
        # Non-callable types, including a union without any callable member
        (int, False),
        (str, False),
        (Any, False),
        (Union[int, str], False),
        # Optional wrapping a callable
        (Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True),
    ],
)
def test_is_callable_type(t, expected):
    """``is_callable_type`` reports the expected verdict for each parametrized annotation."""
    verdict = is_callable_type(t)
    assert verdict == expected


def test_skip_callables_when_creating_pipeline_models():
    """
    Building a request model from inputs that include a callable-typed parameter must succeed,
    and the callable parameter must not appear among the component's schema properties.
    """
    callback_type = Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]]

    def optional_input(type_):
        # Every input in this fixture is optional and defaults to None
        return {"type": type_, "is_mandatory": False, "default_value": None}

    pipeline_inputs = {
        "generator": {
            "system_prompt": optional_input(Optional[str]),
            "streaming_callback": optional_input(callback_type),
            "generation_kwargs": optional_input(Optional[dict[str, Any]]),
        }
    }

    request_model = get_request_model("test_pipeline", pipeline_inputs)

    # This line used to throw an error because the Callable type was not handled correctly
    # by the handle_unsupported_types function
    schema = request_model.model_json_schema()
    assert schema is not None
    assert request_model.__name__ == "Test_pipelineRunRequest"

    component_properties = schema["$defs"]["ComponentParams"]["properties"]
    assert "streaming_callback" not in component_properties

0 comments on commit 29c25d8

Please sign in to comment.