-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #41 from mpangrazzi/main
Handle Callable-like types in get_request_model (eg streaming_callback) + Fix component output serialization
- Loading branch information
Showing
5 changed files
with
165 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from hayhooks.server.pipelines.models import convert_component_output | ||
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails | ||
|
||
|
||
def test_convert_component_output_with_nested_models():
    """convert_component_output must serialize pydantic models nested inside dicts/lists into plain dicts."""
    usage_payload = {
        'completion_tokens': 52,
        'prompt_tokens': 29,
        'total_tokens': 81,
        'completion_tokens_details': CompletionTokensDetails(
            accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
        ),
        'prompt_tokens_details': PromptTokensDetails(audio_tokens=0, cached_tokens=0),
    }
    sample_response = [
        {
            'model': 'gpt-4o-mini-2024-07-18',
            'index': 0,
            'finish_reason': 'stop',
            'usage': usage_payload,
        }
    ]

    # The nested pydantic models above should come back as plain dicts.
    expected = [
        {
            'model': 'gpt-4o-mini-2024-07-18',
            'index': 0,
            'finish_reason': 'stop',
            'usage': {
                'completion_tokens': 52,
                'prompt_tokens': 29,
                'total_tokens': 81,
                'completion_tokens_details': {
                    'accepted_prediction_tokens': 0,
                    'audio_tokens': 0,
                    'reasoning_tokens': 0,
                    'rejected_prediction_tokens': 0,
                },
                'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0},
            },
        }
    ]

    assert convert_component_output(sample_response) == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from collections.abc import Callable as CallableABC | ||
from types import NoneType | ||
from typing import Any, Callable, Optional, Union | ||
|
||
import haystack | ||
import pytest | ||
|
||
from hayhooks.server.pipelines.models import get_request_model | ||
from hayhooks.server.utils.create_valid_type import is_callable_type | ||
|
||
|
||
@pytest.mark.parametrize(
    "candidate, expected",
    [
        (Callable, True),
        (CallableABC, True),
        (Callable[[int], str], True),
        (Callable[..., Any], True),
        (int, False),
        (str, False),
        (Any, False),
        (Union[int, str], False),
        (Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True),
    ],
)
def test_is_callable_type(candidate, expected):
    """is_callable_type must accept bare, parametrized, and Optional-wrapped callables and reject everything else."""
    assert is_callable_type(candidate) == expected
|
||
|
||
def test_skip_callables_when_creating_pipeline_models():
    """Callable-typed component inputs (e.g. streaming_callback) must be omitted from the generated request model."""
    streaming_callback_type = Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]]
    pipeline_inputs = {
        "generator": {
            "system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None},
            "streaming_callback": {
                "type": streaming_callback_type,
                "is_mandatory": False,
                "default_value": None,
            },
            "generation_kwargs": {
                "type": Optional[dict[str, Any]],
                "is_mandatory": False,
                "default_value": None,
            },
        }
    }

    request_model = get_request_model("test_pipeline", pipeline_inputs)

    # Generating the schema used to raise because the Callable type was not
    # handled correctly by the handle_unsupported_types function.
    schema = request_model.model_json_schema()
    assert schema is not None
    assert request_model.__name__ == "Test_pipelineRunRequest"
    assert "streaming_callback" not in schema["$defs"]["ComponentParams"]["properties"]