From 44b010338ba228cfeb627ba92bb63df30eb6a03f Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Dec 2024 11:11:34 +0100 Subject: [PATCH] hfapi w tools --- .../generators/chat/hugging_face_api.py | 150 ++++-- haystack/dataclasses/tool.py | 15 +- haystack/utils/hf.py | 2 +- .../generators/test_hugging_face_api.py | 499 +++++++++++++----- test/dataclasses/test_tool.py | 16 + 5 files changed, 502 insertions(+), 180 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 8711a9175a..cc6462018e 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -5,30 +5,25 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall +from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable -from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model +from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format from haystack.utils.url_validation import is_valid_http_url with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: - from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient + from huggingface_hub import ( + ChatCompletionInputTool, + ChatCompletionOutput, + ChatCompletionStreamOutput, + InferenceClient, + ) logger = logging.getLogger(__name__) -def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: - """ - Convert a message to the format expected by Hugging Face APIs. - - :returns: A dictionary with the following keys: - - `role` - - `content` - """ - return {"role": message.role.value, "content": message.text or ""} - - @component class HuggingFaceAPIChatGenerator: """ @@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ Initialize the HuggingFaceAPIChatGenerator instance. @@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`. - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_GENERATION_INFERENCE`. - :param token: The Hugging Face token to use as HTTP bearer authorization. + :param token: + The Hugging Face token to use as HTTP bearer authorization. Check your HF token in your [account settings](https://huggingface.co/settings/tokens). :param generation_kwargs: A dictionary with keyword arguments to customize text generation. Some examples: `max_tokens`, `temperature`, `top_p`. For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion). 
- :param stop_words: An optional list of strings representing the stop words. - :param streaming_callback: An optional callable for handling streaming responses. + :param stop_words: + An optional list of strings representing the stop words. + :param streaming_callback: + An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. + The chosen model should support tool/function calling, according to the model card. + Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience + unexpected behavior. """ huggingface_hub_import.check() @@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments msg = f"Unknown api_type {api_type}" raise ValueError(msg) + if tools: + if streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) + # handle generation kwargs setup generation_kwargs = generation_kwargs.copy() if generation_kwargs else {} generation_kwargs["stop"] = generation_kwargs.get("stop", []) @@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments self.generation_kwargs = generation_kwargs self.streaming_callback = streaming_callback self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + self.tools = tools def to_dict(self) -> Dict[str, Any]: """ @@ -180,6 +190,7 @@ def to_dict(self) -> Dict[str, Any]: A dictionary containing the serialized component. """ callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None return default_to_dict( self, api_type=str(self.api_type), @@ -187,6 +198,7 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, + tools=serialized_tools, ) @classmethod @@ -195,6 +207,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": Deserialize this component from a dictionary. """ deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: @@ -202,12 +215,22 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": return default_from_dict(cls, data) @component.output_types(replies=List[ChatMessage]) - def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): """ Invoke the text generation inference based on the provided messages and generation parameters. - :param messages: A list of ChatMessage objects representing the input messages. - :param generation_kwargs: Additional keyword arguments for text generation. + :param messages: + A list of ChatMessage objects representing the input messages. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. 
:returns: A dictionary with the following keys: - `replies`: A list containing the generated responses as ChatMessage objects. """ @@ -215,12 +238,22 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # update generation kwargs by merging with the default ones generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} - formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages] + formatted_messages = [convert_message_to_hf_format(message) for message in messages] + + tools = tools or self.tools + if tools: + if self.streaming_callback: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + _check_duplicate_tool_names(tools) if self.streaming_callback: return self._run_streaming(formatted_messages, generation_kwargs) - return self._run_non_streaming(formatted_messages, generation_kwargs) + hf_tools = None + if tools: + hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( @@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict generated_text = "" - for chunk in api_output: # pylint: disable=not-an-iterable - text = chunk.choices[0].delta.content + for chunk in api_output: + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + + text = choice.delta.content if text: generated_text += text - finish_reason = chunk.choices[0].finish_reason + + finish_reason = choice.finish_reason meta = {} if finish_reason: @@ -242,8 +281,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict stream_chunk = StreamingChunk(text, meta) self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method) - message = ChatMessage.from_assistant(generated_text) - message.meta.update( + meta.update( { "model": self._client.model, "finish_reason": finish_reason, @@ -251,24 +289,48 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming } ) + + message = ChatMessage.from_assistant(text=generated_text, meta=meta) + return {"replies": [message]} def _run_non_streaming( - self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any] + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + tools: Optional[List["ChatCompletionInputTool"]] = None, ) -> Dict[str, List[ChatMessage]]: - chat_messages: List[ChatMessage] = [] - - api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs) - for choice in api_chat_output.choices: - message = ChatMessage.from_assistant(choice.message.content) - message.meta.update( - { - "model": self._client.model, - "finish_reason": choice.finish_reason, - "index": choice.index, - "usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0}, - } - ) - chat_messages.append(message) - - return {"replies": chat_messages} + api_chat_output: 
ChatCompletionOutput = self._client.chat_completion( + messages=messages, tools=tools, **generation_kwargs + ) + + if len(api_chat_output.choices) == 0: + return {"replies": []} + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = api_chat_output.choices[0] + + text = choice.message.content + tool_calls = [] + + if hfapi_tool_calls := choice.message.tool_calls: + for hfapi_tc in hfapi_tool_calls: + tool_call = ToolCall( + tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id + ) + tool_calls.append(tool_call) + + meta = { + "model": self._client.model, + "finish_reason": choice.finish_reason, + "index": choice.index, + "usage": { + "prompt_tokens": api_chat_output.usage.prompt_tokens, + "completion_tokens": api_chat_output.usage.completion_tokens, + }, + } + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + return {"replies": [message]} diff --git a/haystack/dataclasses/tool.py b/haystack/dataclasses/tool.py index 3df3fd18f2..4aaf1e2bd1 100644 --- a/haystack/dataclasses/tool.py +++ b/haystack/dataclasses/tool.py @@ -4,7 +4,7 @@ import inspect from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional from pydantic import create_model @@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]): del property_schema[key] +def _check_duplicate_tool_names(tools: List[Tool]) -> None: + """ + Check for duplicate tool names. + + :param tools: The list of tools to check. + :raises ValueError: If duplicate tool names are found. + """ + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"): """ Deserialize Tools in a dictionary inplace. 
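
For reference, the new `_check_duplicate_tool_names` helper above is what the chat generator calls in `__init__` and `run` before accepting a tool list. A minimal sketch of its behavior, not part of the patch itself; the weather tool is illustrative and mirrors the test fixtures further down:

from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names


# Illustrative tool function; any callable works for this check.
def get_weather(city: str) -> str:
    return f"Sunny in {city}"


parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}

tools = [
    Tool(name="weather", description="Get a weather report", parameters=parameters, function=get_weather),
    Tool(name="weather", description="Same name, different tool", parameters=parameters, function=get_weather),
]

# Both tools are named "weather", so this raises
# ValueError("Duplicate tool names found: {'weather'}")
_check_duplicate_tool_names(tools)
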
diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index dbff3f22dc..6bc8169685 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -280,7 +280,7 @@ def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]: if not text_contents and not tool_calls and not tool_call_results: raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") - elif len(text_contents) + len(tool_call_results) > 1: + if len(text_contents) + len(tool_call_results) > 1: raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") # HF always expects a content field, even if it is empty diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py index 0f4be2f9cb..0d0857e22a 100644 --- a/test/components/generators/test_hugging_face_api.py +++ b/test/components/generators/test_hugging_face_api.py @@ -5,38 +5,78 @@ from unittest.mock import MagicMock, Mock, patch import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType from huggingface_hub import ( - TextGenerationOutputToken, - TextGenerationStreamOutput, - TextGenerationStreamOutputStreamDetails, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, ) from huggingface_hub.utils import RepositoryNotFoundError -from haystack.components.generators import HuggingFaceAPIGenerator -from haystack.dataclasses import StreamingChunk -from haystack.utils.auth import Secret -from haystack.utils.hf import HFGenerationAPIType +from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator +from haystack.dataclasses import ChatMessage, Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] @pytest.fixture def mock_check_valid_model(): with patch( - "haystack.components.generators.hugging_face_api.check_valid_model", MagicMock(return_value=None) + "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None) ) as mock: yield mock @pytest.fixture -def mock_text_generation(): - with patch("huggingface_hub.InferenceClient.text_generation", autospec=True) as mock_text_generation: - mock_response = Mock() - mock_response.generated_text = "I'm fine, thanks." 
- details = Mock() - details.finish_reason = MagicMock(field1="value") - details.tokens = [1, 2, 3] - mock_response.details = details - mock_text_generation.return_value = mock_response - yield mock_text_generation +def mock_chat_completion(): + # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), + created=1710498360, + ) + + mock_chat_completion.return_value = completion + yield mock_chat_completion # used to test serialization of streaming_callback @@ -44,10 +84,10 @@ def streaming_callback_handler(x): return x -class TestHuggingFaceAPIGenerator: +class TestHuggingFaceAPIChatGenerator: def test_init_invalid_api_type(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator(api_type="invalid_api_type", api_params={}) + HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) def test_init_serverless(self, mock_check_valid_model): model = "HuggingFaceH4/zephyr-7b-alpha" @@ -55,7 +95,7 @@ def test_init_serverless(self, mock_check_valid_model): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}, token=None, @@ -66,23 +106,42 @@ def test_init_serverless(self, mock_check_valid_model): assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator.api_params == {"model": model} - assert generator.generation_kwargs == { - **generation_kwargs, - **{"stop_sequences": ["stop"]}, - **{"max_new_tokens": 512}, - } + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_serverless_with_tools(self, mock_check_valid_model, tools): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + tools=tools, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools == tools def test_init_serverless_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} ) def test_init_serverless_no_model(self): with 
pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} ) @@ -92,7 +151,7 @@ def test_init_tgi(self): stop_words = ["stop"] streaming_callback = None - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": url}, token=None, @@ -103,31 +162,49 @@ def test_init_tgi(self): assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE assert generator.api_params == {"url": url} - assert generator.generation_kwargs == { - **generation_kwargs, - **{"stop_sequences": ["stop"]}, - **{"max_new_tokens": 512}, - } + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} assert generator.streaming_callback == streaming_callback + assert generator.tools is None def test_init_tgi_invalid_url(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"} ) def test_init_tgi_no_url(self): with pytest.raises(ValueError): - HuggingFaceAPIGenerator( + HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} ) + def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=duplicate_tools, + ) + + def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=tools, + streaming_callback=streaming_callback_handler, + ) + def test_to_dict(self, mock_check_valid_model): - generator = HuggingFaceAPIGenerator( + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], + tools=[tool], ) result = generator.to_dict() @@ -136,101 +213,118 @@ def test_to_dict(self, mock_check_valid_model): assert init_params["api_type"] == "serverless_inference_api" assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} - assert init_params["generation_kwargs"] == { - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "max_new_tokens": 512, - } + assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert init_params["streaming_callback"] is None + assert init_params["tools"] == [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": {"x": {"type": "string"}}, + } + ] def test_from_dict(self, mock_check_valid_model): - generator = HuggingFaceAPIGenerator( + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, + tools=[tool], ) result = generator.to_dict() # now deserialize, call from_dict - generator_2 = HuggingFaceAPIGenerator.from_dict(result) + generator_2 = HuggingFaceAPIChatGenerator.from_dict(result) assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) - assert generator_2.generation_kwargs == { - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "max_new_tokens": 512, - } - assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert generator_2.streaming_callback is None + assert generator_2.tools == [tool] - def test_generate_text_response_with_valid_prompt_and_generation_parameters( - self, mock_check_valid_model, mock_text_generation - ): - generator = HuggingFaceAPIGenerator( + def test_serde_in_pipeline(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"temperature": 0.6}, stop_words=["stop", "words"], - streaming_callback=None, + tools=[tool], ) - prompt = "Hello, how are you?" 
- response = generator.run(prompt) - - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == { - "details": True, - "temperature": 0.6, - "stop_sequences": ["stop", "words"], - "stream": False, - "max_new_tokens": 512, + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", + "init_parameters": { + "api_type": "serverless_inference_api", + "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, + "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], } - assert isinstance(response, dict) - assert "replies" in response - assert "meta" in response - assert isinstance(response["replies"], list) - assert isinstance(response["meta"], list) - assert len(response["replies"]) == 1 - assert len(response["meta"]) == 1 - assert [isinstance(reply, str) for reply in response["replies"]] + pipeline_yaml = pipeline.dumps() - def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation): - generator = HuggingFaceAPIGenerator( - api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"} + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + streaming_callback=None, ) - generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100} - response = generator.run("How are you?", generation_kwargs=generation_kwargs) + response = generator.run(messages=chat_messages) - # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args + # check kwargs passed to chat_completion + _, kwargs = mock_chat_completion.call_args + hf_messages = [ + {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, + {"role": "user", "content": "Tell me about Berlin"}, + ] assert kwargs == { - "details": True, - "max_new_tokens": 100, - "stop_sequences": [], - "stream": False, - "temperature": 0.8, + "temperature": 0.6, + "stop": ["stop", "words"], + "max_tokens": 512, + "tools": None, + "messages": hf_messages, } - # Assert that the response contains the generated replies and the right response + assert isinstance(response, dict) assert "replies" in response assert isinstance(response["replies"], list) - assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] - assert response["replies"][0] == "I'm fine, thanks." 
- - # Assert that the response contains the metadata - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] - def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation): + def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): streaming_call_count = 0 # Define the streaming callback function @@ -239,38 +333,50 @@ def streaming_callback_fn(chunk: StreamingChunk): streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) - generator = HuggingFaceAPIGenerator( + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, - api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, streaming_callback=streaming_callback_fn, ) # Create a fake streamed response - # Don't remove self + # self needed here, don't remove def mock_iter(self): - yield TextGenerationStreamOutput( - index=0, - generated_text=None, - token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False), + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), + index=0, + finish_reason=None, + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, ) - yield TextGenerationStreamOutput( - index=1, - generated_text=None, - token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False), - details=TextGenerationStreamOutputStreamDetails( - finish_reason="length", generated_tokens=5, seed=None, input_length=10 - ), + + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, ) mock_response = Mock(**{"__iter__": mock_iter}) - mock_text_generation.return_value = mock_response + mock_chat_completion.return_value = mock_response # Generate text response with streaming callback - response = generator.run("prompt") + response = generator.run(chat_messages) # check kwargs passed to text_generation - _, kwargs = mock_text_generation.call_args - assert kwargs == {"details": True, "stop_sequences": [], "stream": True, "max_new_tokens": 512} + _, kwargs = mock_chat_completion.call_args + assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} # Assert that the streaming callback was called twice assert streaming_call_count == 2 @@ -279,36 +385,161 @@ def mock_iter(self): assert "replies" in response assert isinstance(response["replies"], list) assert len(response["replies"]) > 0 - assert [isinstance(reply, str) for reply in response["replies"]] + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model): + component = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_handler, + ) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + 
component.run([message], tools=tools) + + def test_run_with_tools(self, mock_check_valid_model, tools): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, + tools=tools, + ) - # Assert that the response contains the metadata - assert "meta" in response - assert isinstance(response["meta"], list) - assert len(response["meta"]) > 0 - assert [isinstance(meta, dict) for meta in response["meta"]] + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="stop", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionOutputToolCall( + function=ChatCompletionOutputFunctionDefinition( + arguments={"city": "Paris"}, name="weather", description=None + ), + id="0", + type="function", + ) + ], + ), + logprobs=None, + ) + ], + created=1729074760, + id="", + model="meta-llama/Llama-3.1-70B-Instruct", + system_fingerprint="2.3.2-dev0-sha-28bb7ae", + usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), + ) + mock_chat_completion.return_value = completion + + messages = [ChatMessage.from_user("What is the weather in Paris?")] + response = generator.run(messages=messages) + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].tool_calls[0].tool_name == "weather" + assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} + assert response["replies"][0].tool_calls[0].id == "0" + assert response["replies"][0].meta == { + "finish_reason": "stop", + "index": 0, + "model": "meta-llama/Llama-3.1-70B-Instruct", + "usage": {"completion_tokens": 30, "prompt_tokens": 426}, + } - @pytest.mark.flaky(reruns=5, reruns_delay=5) @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("HF_API_TOKEN", None), reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", ) - def test_run_serverless(self): - generator = HuggingFaceAPIGenerator( + def test_live_run_serverless(self): + generator = HuggingFaceAPIChatGenerator( api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, - generation_kwargs={"max_new_tokens": 20}, + generation_kwargs={"max_tokens": 20}, ) - response = generator.run("How are you?") - # Assert that the response contains the generated replies + messages = [ChatMessage.from_user("What is the capital of France?")] + response = generator.run(messages=messages) + + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "usage" in response["replies"][0].meta + assert "prompt_tokens" in response["replies"][0].meta["usage"] + assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_serverless_streaming(self): + generator = HuggingFaceAPIChatGenerator( + 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"max_tokens": 20},
+            streaming_callback=streaming_callback_handler,
+        )
+
+        messages = [ChatMessage.from_user("What is the capital of France?")]
+        response = generator.run(messages=messages)
+
         assert "replies" in response
         assert isinstance(response["replies"], list)
         assert len(response["replies"]) > 0
-        assert [isinstance(reply, str) for reply in response["replies"]]
+        assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
+        assert "usage" in response["replies"][0].meta
+        assert "prompt_tokens" in response["replies"][0].meta["usage"]
+        assert "completion_tokens" in response["replies"][0].meta["usage"]
+
+    @pytest.mark.integration
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+    )
+    def test_live_run_with_tools(self, tools):
+        """
+        We test the round trip: generate a tool call, pass the tool message, generate a response.
+
+        The model used here (zephyr-7b-beta) is always available and not gated.
+        Even though it does not officially support tools, TGI and the HF API make it work.
+        """
+
+        chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")]
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+            generation_kwargs={"temperature": 0.5},
+        )
+
+        results = generator.run(chat_messages, tools=tools)
+        assert len(results["replies"]) == 1
+        message = results["replies"][0]
+
+        assert message.tool_calls
+        tool_call = message.tool_call
+        assert isinstance(tool_call, ToolCall)
+        assert tool_call.tool_name == "weather"
+        assert "city" in tool_call.arguments
+        assert "Paris" in tool_call.arguments["city"]
+        assert message.meta["finish_reason"] == "stop"
+
+        new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)]
+
+        # the model tends to make tool calls if provided with tools, so we don't pass them here
+        results = generator.run(new_messages, generation_kwargs={"max_tokens": 50})
 
-        # Assert that the response contains the metadata
-        assert "meta" in response
-        assert isinstance(response["meta"], list)
-        assert len(response["meta"]) > 0
-        assert [isinstance(meta, dict) for meta in response["meta"]]
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_calls
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
diff --git a/test/dataclasses/test_tool.py b/test/dataclasses/test_tool.py
index db9719a7f3..9e112853f3 100644
--- a/test/dataclasses/test_tool.py
+++ b/test/dataclasses/test_tool.py
@@ -12,6 +12,7 @@
     ToolInvocationError,
     _remove_title_from_schema,
     deserialize_tools_inplace,
+    _check_duplicate_tool_names,
 )
 
 try:
@@ -303,3 +304,18 @@ def test_remove_title_from_schema_handle_no_title_in_top_level():
         "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
         "type": "object",
     }
+
+
+def test_check_duplicate_tool_names():
+    tools = [
+        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
+        Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report),
+    ]
+    with pytest.raises(ValueError):
+        _check_duplicate_tool_names(tools)
+
+    tools = [ 
+ Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report), + Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report), + ] + _check_duplicate_tool_names(tools)
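
Taken together, these changes support the tool-calling round trip exercised in test_live_run_with_tools. Below is a minimal usage sketch, assuming HF_API_TOKEN is exported; the get_weather function is an illustrative stand-in, while the model, tool schema, and message flow mirror the tests above.

from haystack.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, Tool
from haystack.utils.hf import HFGenerationAPIType


# Illustrative tool function; the JSON schema matches the fixture used in the tests above.
def get_weather(city: str) -> str:
    return f"22° C and sunny in {city}"


weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # token is read from HF_API_TOKEN / HF_TOKEN
)

messages = [ChatMessage.from_user("What's the weather like in Paris?")]

# First call: with tools available, the model is expected to reply with a tool call instead of text.
reply = generator.run(messages, tools=[weather_tool])["replies"][0]

if reply.tool_calls:
    tool_call = reply.tool_call
    tool_result = get_weather(**tool_call.arguments)

    # Second call: feed the tool result back; tools are omitted so the model answers in plain text.
    new_messages = messages + [reply, ChatMessage.from_tool(tool_result=tool_result, origin=tool_call)]
    final = generator.run(new_messages, generation_kwargs={"max_tokens": 50})
    print(final["replies"][0].text)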