-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Tool Invoker component (#8664)
* port toolinvoker * release note
- Loading branch information
Showing
6 changed files
with
509 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
loaders: | ||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader | ||
search_path: [../../../haystack/components/tools] | ||
modules: ["tool_invoker"] | ||
ignore_when_discovered: ["__init__"] | ||
processors: | ||
- type: filter | ||
expression: | ||
documented_only: true | ||
do_not_filter_modules: false | ||
skip_empty_modules: true | ||
- type: smart | ||
- type: crossref | ||
renderer: | ||
type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer | ||
excerpt: Components related to Tool Calling. | ||
category_slug: haystack-api | ||
title: Tools | ||
slug: tools-api | ||
order: 152 | ||
markdown: | ||
descriptive_class_title: false | ||
classdef_code_block: false | ||
descriptive_module_title: true | ||
add_method_class_prefix: true | ||
add_member_class_prefix: false | ||
filename: tools_api.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from haystack.components.tools.tool_invoker import ToolInvoker | ||
|
||
_all_ = ["ToolInvoker"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import warnings | ||
from typing import Any, Dict, List | ||
|
||
from haystack import component, default_from_dict, default_to_dict, logging | ||
from haystack.dataclasses.chat_message import ChatMessage, ToolCall | ||
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}." | ||
_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}." | ||
_TOOL_RESULT_CONVERSION_FAILURE = ( | ||
"Failed to convert tool result to string using '{conversion_function}'. Error: {error}." | ||
) | ||
|
||
|
||
class ToolNotFoundException(Exception): | ||
""" | ||
Exception raised when a tool is not found in the list of available tools. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class StringConversionError(Exception): | ||
""" | ||
Exception raised when the conversion of a tool result to a string fails. | ||
""" | ||
|
||
pass | ||
|
||
|
||
@component | ||
class ToolInvoker: | ||
""" | ||
Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects. | ||
At initialization, the ToolInvoker component is provided with a list of available tools. | ||
At runtime, the component processes a list of ChatMessage object containing tool calls | ||
and invokes the corresponding tools. | ||
The results of the tool invocations are returned as a list of ChatMessage objects with tool role. | ||
Usage example: | ||
```python | ||
from haystack.dataclasses import ChatMessage, ToolCall, Tool | ||
from haystack.components.tools import ToolInvoker | ||
# Tool definition | ||
def dummy_weather_function(city: str): | ||
return f"The weather in {city} is 20 degrees." | ||
parameters = {"type": "object", | ||
"properties": {"city": {"type": "string"}}, | ||
"required": ["city"]} | ||
tool = Tool(name="weather_tool", | ||
description="A tool to get the weather", | ||
function=dummy_weather_function, | ||
parameters=parameters) | ||
# Usually, the ChatMessage with tool_calls is generated by a Language Model | ||
# Here, we create it manually for demonstration purposes | ||
tool_call = ToolCall( | ||
tool_name="weather_tool", | ||
arguments={"city": "Berlin"} | ||
) | ||
message = ChatMessage.from_assistant(tool_calls=[tool_call]) | ||
# ToolInvoker initialization and run | ||
invoker = ToolInvoker(tools=[tool]) | ||
result = invoker.run(messages=[message]) | ||
print(result) | ||
``` | ||
``` | ||
>> { | ||
>> 'tool_messages': [ | ||
>> ChatMessage( | ||
>> _role=<ChatRole.TOOL: 'tool'>, | ||
>> _content=[ | ||
>> ToolCallResult( | ||
>> result='"The weather in Berlin is 20 degrees."', | ||
>> origin=ToolCall( | ||
>> tool_name='weather_tool', | ||
>> arguments={'city': 'Berlin'}, | ||
>> id=None | ||
>> ) | ||
>> ) | ||
>> ], | ||
>> _meta={} | ||
>> ) | ||
>> ] | ||
>> } | ||
``` | ||
""" | ||
|
||
def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False): | ||
""" | ||
Initialize the ToolInvoker component. | ||
:param tools: | ||
A list of tools that can be invoked. | ||
:param raise_on_failure: | ||
If True, the component will raise an exception in case of errors | ||
(tool not found, tool invocation errors, tool result conversion errors). | ||
If False, the component will return a ChatMessage object with `error=True` | ||
and a description of the error in `result`. | ||
:param convert_result_to_json_string: | ||
If True, the tool invocation result will be converted to a string using `json.dumps`. | ||
If False, the tool invocation result will be converted to a string using `str`. | ||
:raises ValueError: | ||
If no tools are provided or if duplicate tool names are found. | ||
""" | ||
|
||
msg = "The `ToolInvoker` component is experimental and its API may change in the future." | ||
warnings.warn(msg) | ||
|
||
if not tools: | ||
raise ValueError("ToolInvoker requires at least one tool to be provided.") | ||
_check_duplicate_tool_names(tools) | ||
|
||
self.tools = tools | ||
self._tools_with_names = dict(zip([tool.name for tool in tools], tools)) | ||
self.raise_on_failure = raise_on_failure | ||
self.convert_result_to_json_string = convert_result_to_json_string | ||
|
||
def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage: | ||
""" | ||
Prepares a ChatMessage with the result of a tool invocation. | ||
:param result: | ||
The tool result. | ||
:returns: | ||
A ChatMessage object containing the tool result as a string. | ||
:raises | ||
StringConversionError: If the conversion of the tool result to a string fails | ||
and `raise_on_failure` is True. | ||
""" | ||
error = False | ||
|
||
if self.convert_result_to_json_string: | ||
try: | ||
# We disable ensure_ascii so special chars like emojis are not converted | ||
tool_result_str = json.dumps(result, ensure_ascii=False) | ||
except Exception as e: | ||
if self.raise_on_failure: | ||
raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e | ||
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps") | ||
error = True | ||
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) | ||
|
||
try: | ||
tool_result_str = str(result) | ||
except Exception as e: | ||
if self.raise_on_failure: | ||
raise StringConversionError("Failed to convert tool result to string using `str`") from e | ||
tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str") | ||
error = True | ||
return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) | ||
|
||
@component.output_types(tool_messages=List[ChatMessage]) | ||
def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: | ||
""" | ||
Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available. | ||
:param messages: | ||
A list of ChatMessage objects. | ||
:returns: | ||
A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role. | ||
Each ChatMessage objects wraps the result of a tool invocation. | ||
:raises ToolNotFoundException: | ||
If the tool is not found in the list of available tools and `raise_on_failure` is True. | ||
:raises ToolInvocationError: | ||
If the tool invocation fails and `raise_on_failure` is True. | ||
:raises StringConversionError: | ||
If the conversion of the tool result to a string fails and `raise_on_failure` is True. | ||
""" | ||
tool_messages = [] | ||
|
||
for message in messages: | ||
tool_calls = message.tool_calls | ||
if not tool_calls: | ||
continue | ||
|
||
for tool_call in tool_calls: | ||
tool_name = tool_call.tool_name | ||
tool_arguments = tool_call.arguments | ||
|
||
if not tool_name in self._tools_with_names: | ||
msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys()) | ||
if self.raise_on_failure: | ||
raise ToolNotFoundException(msg) | ||
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) | ||
continue | ||
|
||
tool_to_invoke = self._tools_with_names[tool_name] | ||
try: | ||
tool_result = tool_to_invoke.invoke(**tool_arguments) | ||
except ToolInvocationError as e: | ||
if self.raise_on_failure: | ||
raise e | ||
msg = _TOOL_INVOCATION_FAILURE.format(error=e) | ||
tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) | ||
continue | ||
|
||
tool_message = self._prepare_tool_result_message(tool_result, tool_call) | ||
tool_messages.append(tool_message) | ||
|
||
return {"tool_messages": tool_messages} | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serializes the component to a dictionary. | ||
:returns: | ||
Dictionary with serialized data. | ||
""" | ||
serialized_tools = [tool.to_dict() for tool in self.tools] | ||
return default_to_dict( | ||
self, | ||
tools=serialized_tools, | ||
raise_on_failure=self.raise_on_failure, | ||
convert_result_to_json_string=self.convert_result_to_json_string, | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker": | ||
""" | ||
Deserializes the component from a dictionary. | ||
:param data: | ||
The dictionary to deserialize from. | ||
:returns: | ||
The deserialized component. | ||
""" | ||
deserialize_tools_inplace(data["init_parameters"], key="tools") | ||
return default_from_dict(cls, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
--- | ||
features: | ||
- | | ||
Add a new experimental component `ToolInvoker`. | ||
This component invokes tools based on tool calls prepared by Language Models and returns the results as a list of | ||
ChatMessage objects with tool role. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
Oops, something went wrong.