Skip to content

Commit

Permalink
skip callable types when creating pipeline ; recursively serialize co…
Browse files Browse the repository at this point in the history
…mponent output to avoid pydantic serialization errors (eg on OpenAI responses)
  • Loading branch information
mpangrazzi committed Nov 19, 2024
1 parent 8ebd76c commit d1adc2b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 38 deletions.
62 changes: 33 additions & 29 deletions src/hayhooks/server/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, create_model
from typing import Callable

from hayhooks.server.utils.create_valid_type import handle_unsupported_types


Expand All @@ -28,14 +26,16 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
component_model = {}
for name, typedef in inputs.items():
try:
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict, Callable: dict})
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
except TypeError as e:
print(f"ERROR at {component_name!r}, {name}: {typedef}")
raise e
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)

if input_type is not None:
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)
request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)
Expand All @@ -62,30 +62,34 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)


def convert_value_to_dict(value):
    """
    Convert a single value to plain dict/list data if possible.

    Conversion rules, checked in this order:

    1. Objects exposing ``to_dict()``: the result of ``to_dict()`` is returned,
       unless it contains an ``"init_parameters"`` key, in which case only the
       value stored under that key is returned.
    2. Objects exposing ``model_dump()``: the result of ``model_dump()`` is returned.
    3. ``dict``: a new dict with every value converted recursively (keys untouched).
    4. ``list``: a new list with every item converted recursively.
    5. Anything else is returned unchanged.

    Args:
        value: Any value found in a component output.

    Returns:
        The converted value, or ``value`` itself when no rule applies.
    """
    if hasattr(value, "to_dict"):
        # Serialize exactly once: to_dict() may be expensive, and calling it
        # repeatedly could yield a result that differs from the one inspected.
        serialized = value.to_dict()
        if "init_parameters" in serialized:
            return serialized["init_parameters"]
        return serialized

    if hasattr(value, "model_dump"):
        return value.model_dump()

    if isinstance(value, dict):
        return {key: convert_value_to_dict(item) for key, item in value.items()}

    if isinstance(value, list):
        return [convert_value_to_dict(item) for item in value]

    return value


def convert_component_output(component_output):
"""
Converts outputs from a component as a dict so that it can be validated against response model
Component output has this form:
Converts component outputs to dictionaries that can be validated against response model.
Handles nested structures recursively.
"documents":[
{"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."}
]
Args:
component_output: Dict with component outputs
Returns:
Dict with all nested objects converted to dictionaries
"""
result = {}
for output_name, data in component_output.items():

def get_value(data):
if hasattr(data, "to_dict") and "init_parameters" in data.to_dict():
return data.to_dict()["init_parameters"]
elif hasattr(data, "to_dict"):
return data.to_dict()
else:
return data

if type(data) is list:
result[output_name] = [get_value(d) for d in data]
else:
result[output_name] = get_value(data)
return result
if isinstance(component_output, dict):
return {name: convert_value_to_dict(data) for name, data in component_output.items()}

return convert_value_to_dict(component_output)
16 changes: 8 additions & 8 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ def is_callable_type(t):
return False


def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
def handle_unsupported_types(
type_: type, types_mapping: Dict[type, type], skip_callables: bool = True
) -> Union[GenericAlias, type, None]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
"""

def handle_generics(t_) -> GenericAlias:
def handle_generics(t_) -> Union[GenericAlias, None]:
"""Handle generics recursively"""
if is_callable_type(t_):
return types_mapping[Callable]
if is_callable_type(t_) and skip_callables:
return None

child_typing = []
for t in get_args(t_):
if t in types_mapping:
result = types_mapping[t]
elif is_callable_type(t):
result = types_mapping[Callable]
elif isclass(t):
result = handle_unsupported_types(t, types_mapping)
else:
Expand All @@ -49,8 +49,8 @@ def handle_generics(t_) -> GenericAlias:
else:
return GenericAlias(get_origin(t_), tuple(child_typing))

if is_callable_type(type_):
return types_mapping[Callable]
if is_callable_type(type_) and skip_callables:
return None

if isclass(type_):
new_type = {}
Expand Down
43 changes: 43 additions & 0 deletions tests/test_convert_component_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from hayhooks.server.pipelines.models import convert_component_output
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails


def test_convert_component_output_with_nested_models():
    """OpenAI usage-detail objects nested inside a list of dicts come back as plain dicts."""
    completion_details = CompletionTokensDetails(
        accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
    )
    prompt_details = PromptTokensDetails(audio_tokens=0, cached_tokens=0)

    reply_meta = {
        'model': 'gpt-4o-mini-2024-07-18',
        'index': 0,
        'finish_reason': 'stop',
        'usage': {
            'completion_tokens': 52,
            'prompt_tokens': 29,
            'total_tokens': 81,
            'completion_tokens_details': completion_details,
            'prompt_tokens_details': prompt_details,
        },
    }

    # Same structure as the input, with the two model objects replaced by dicts.
    expected_usage = {
        'completion_tokens': 52,
        'prompt_tokens': 29,
        'total_tokens': 81,
        'completion_tokens_details': {
            'accepted_prediction_tokens': 0,
            'audio_tokens': 0,
            'reasoning_tokens': 0,
            'rejected_prediction_tokens': 0,
        },
        'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0},
    }
    expected = [
        {
            'model': 'gpt-4o-mini-2024-07-18',
            'index': 0,
            'finish_reason': 'stop',
            'usage': expected_usage,
        }
    ]

    assert convert_component_output([reply_meta]) == expected
3 changes: 2 additions & 1 deletion tests/test_handle_callable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_is_callable_type(t, expected):
assert is_callable_type(t) == expected


def test_handle_callable_type_when_creating_pipeline_models():
def test_skip_callables_when_creating_pipeline_models():
pipeline_name = "test_pipeline"
pipeline_inputs = {
"generator": {
Expand All @@ -51,3 +51,4 @@ def test_handle_callable_type_when_creating_pipeline_models():
# by the handle_unsupported_types function
assert request_model.model_json_schema() is not None
assert request_model.__name__ == "Test_pipelineRunRequest"
assert "streaming_callback" not in request_model.model_json_schema()["$defs"]["ComponentParams"]["properties"]

0 comments on commit d1adc2b

Please sign in to comment.