hfapi w tools

anakin87 committed Dec 19, 2024
1 parent 4db6f40 commit 44b0103
Showing 5 changed files with 502 additions and 180 deletions.
150 changes: 106 additions & 44 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -5,30 +5,25 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model, convert_message_to_hf_format
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient
from huggingface_hub import (
ChatCompletionInputTool,
ChatCompletionOutput,
ChatCompletionStreamOutput,
InferenceClient,
)


logger = logging.getLogger(__name__)


def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]:
"""
Convert a message to the format expected by Hugging Face APIs.
:returns: A dictionary with the following keys:
- `role`
- `content`
"""
return {"role": message.role.value, "content": message.text or ""}


@component
class HuggingFaceAPIChatGenerator:
"""
@@ -107,6 +102,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Initialize the HuggingFaceAPIChatGenerator instance.
@@ -121,14 +117,22 @@ def __init__( # pylint: disable=too-many-positional-arguments
- `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
`TEXT_GENERATION_INFERENCE`.
:param token: The Hugging Face token to use as HTTP bearer authorization.
:param token:
The Hugging Face token to use as HTTP bearer authorization.
Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
:param generation_kwargs:
A dictionary with keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_p`.
For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion).
:param stop_words: An optional list of strings representing the stop words.
:param streaming_callback: An optional callable for handling streaming responses.
:param stop_words:
An optional list of strings representing the stop words.
:param streaming_callback:
An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls.
The chosen model should support tool/function calling, according to the model card.
Support for tools in the Hugging Face API and TGI is not yet fully refined,
so you may experience unexpected behavior.
"""

huggingface_hub_import.check()
@@ -159,6 +163,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
msg = f"Unknown api_type {api_type}"
raise ValueError(msg)

if tools:
if streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
generation_kwargs["stop"] = generation_kwargs.get("stop", [])
@@ -171,6 +180,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
self.tools = tools
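A minimal usage sketch for the new tools parameter (the tool, model ID, and weather function here are hypothetical; the chosen model must support function calling):

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses.tool import Tool
from haystack.utils.hf import HFGenerationAPIType

# Hypothetical example tool: a name, a JSON-schema description of the
# arguments, and the callable to run when the model requests it.
def get_weather(city: str) -> str:
    return f"Sunny in {city}"

weather_tool = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "meta-llama/Llama-3.1-8B-Instruct"},  # assumed model ID
    tools=[weather_tool],
)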

def to_dict(self) -> Dict[str, Any]:
"""
@@ -180,13 +190,15 @@ def to_dict(self) -> Dict[str, Any]:
A dictionary containing the serialized component.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
return default_to_dict(
self,
api_type=str(self.api_type),
api_params=self.api_params,
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
tools=serialized_tools,
)

@classmethod
@@ -195,32 +207,53 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator":
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
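Continuing the sketch above: tools survive a serialization round trip, since to_dict stores each tool via Tool.to_dict and deserialize_tools_inplace rebuilds them before the component is reconstructed.

data = generator.to_dict()
restored = HuggingFaceAPIChatGenerator.from_dict(data)
assert restored.tools is not None and restored.tools[0].name == "get_weather"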

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Invoke the text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param messages:
A list of ChatMessage objects representing the input messages.
:param generation_kwargs:
Additional keyword arguments for text generation.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: A list containing the generated responses as ChatMessage objects.
"""

# update generation kwargs by merging with the default ones
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages]
formatted_messages = [convert_message_to_hf_format(message) for message in messages]

tools = tools or self.tools
if tools:
if self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)

return self._run_non_streaming(formatted_messages, generation_kwargs)
hf_tools = None
if tools:
hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)
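At call time, each Tool's tool_spec (name, description, parameters) is wrapped in the OpenAI-style {"type": "function", ...} envelope shown above. Continuing the sketch, a reply may then carry tool calls instead of plain text:

from haystack.dataclasses import ChatMessage

result = generator.run(messages=[ChatMessage.from_user("What is the weather in Paris?")])
reply = result["replies"][0]
for tool_call in reply.tool_calls:  # empty if the model answered with plain text
    print(tool_call.tool_name, tool_call.arguments)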

def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
@@ -229,11 +262,17 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict

generated_text = ""

for chunk in api_output: # pylint: disable=not-an-iterable
text = chunk.choices[0].delta.content
for chunk in api_output:
# the `n` parameter is unused, so the API always returns only one choice;
# it is probably accepted only for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = chunk.choices[0]

text = choice.delta.content
if text:
generated_text += text
finish_reason = chunk.choices[0].finish_reason

finish_reason = choice.finish_reason

meta = {}
if finish_reason:
@@ -242,33 +281,56 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
stream_chunk = StreamingChunk(text, meta)
self.streaming_callback(stream_chunk) # type: ignore # streaming_callback is not None (verified in the run method)

message = ChatMessage.from_assistant(generated_text)
message.meta.update(
meta.update(
{
"model": self._client.model,
"finish_reason": finish_reason,
"index": 0,
"usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming
}
)

message = ChatMessage.from_assistant(text=generated_text, meta=meta)

return {"replies": [message]}

def _run_non_streaming(
self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]
self,
messages: List[Dict[str, str]],
generation_kwargs: Dict[str, Any],
tools: Optional[List["ChatCompletionInputTool"]] = None,
) -> Dict[str, List[ChatMessage]]:
chat_messages: List[ChatMessage] = []

api_chat_output: ChatCompletionOutput = self._client.chat_completion(messages, **generation_kwargs)
for choice in api_chat_output.choices:
message = ChatMessage.from_assistant(choice.message.content)
message.meta.update(
{
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": api_chat_output.usage or {"prompt_tokens": 0, "completion_tokens": 0},
}
)
chat_messages.append(message)

return {"replies": chat_messages}
api_chat_output: ChatCompletionOutput = self._client.chat_completion(
messages=messages, tools=tools, **generation_kwargs
)

if len(api_chat_output.choices) == 0:
return {"replies": []}

# the `n` parameter is unused, so the API always returns only one choice;
# it is probably accepted only for compatibility with OpenAI
# see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
choice = api_chat_output.choices[0]

text = choice.message.content
tool_calls = []

if hfapi_tool_calls := choice.message.tool_calls:
for hfapi_tc in hfapi_tool_calls:
tool_call = ToolCall(
tool_name=hfapi_tc.function.name, arguments=hfapi_tc.function.arguments, id=hfapi_tc.id
)
tool_calls.append(tool_call)

meta = {
"model": self._client.model,
"finish_reason": choice.finish_reason,
"index": choice.index,
"usage": {
"prompt_tokens": api_chat_output.usage.prompt_tokens,
"completion_tokens": api_chat_output.usage.completion_tokens,
},
}

message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
return {"replies": [message]}
15 changes: 14 additions & 1 deletion haystack/dataclasses/tool.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import inspect
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional

from pydantic import create_model

@@ -216,6 +216,19 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
del property_schema[key]


def _check_duplicate_tool_names(tools: List[Tool]) -> None:
"""
Check for duplicate tool names.
:param tools: The list of tools to check.
:raises ValueError: If duplicate tool names are found.
"""
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")


def deserialize_tools_inplace(data: Dict[str, Any], key: str = "tools"):
"""
Deserialize Tools in a dictionary inplace.
2 changes: 1 addition & 1 deletion haystack/utils/hf.py
Original file line number Diff line number Diff line change
@@ -280,7 +280,7 @@ def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:

if not text_contents and not tool_calls and not tool_call_results:
raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
elif len(text_contents) + len(tool_call_results) > 1:
if len(text_contents) + len(tool_call_results) > 1:
raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

# HF always expects a content field, even if it is empty
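A sketch of the resulting shapes for the two common cases, a plain user message and an assistant tool call (output abbreviated; tool name and arguments are illustrative):

from haystack.dataclasses import ChatMessage, ToolCall
from haystack.utils.hf import convert_message_to_hf_format

print(convert_message_to_hf_format(ChatMessage.from_user("Hi")))
# -> {"role": "user", "content": "Hi"}

call = ToolCall(tool_name="get_weather", arguments={"city": "Paris"}, id="call_0")
print(convert_message_to_hf_format(ChatMessage.from_assistant(tool_calls=[call])))
# -> {"role": "assistant", "content": "", "tool_calls": [...]}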