From 7dcbf25bd73fd8e44819a972b72fc7cf55b4bfca Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 20 Dec 2024 14:02:42 +0100 Subject: [PATCH] feat: add Tool Invoker component (#8664) * port toolinvoker * release note --- docs/pydoc/config/tools_api.yml | 27 ++ haystack/components/tools/__init__.py | 7 + haystack/components/tools/tool_invoker.py | 246 ++++++++++++++++++ .../add-tool-invoker-3bc567b27aec2f32.yaml | 6 + test/components/tools/__init__.py | 3 + test/components/tools/test_tool_invoker.py | 220 ++++++++++++++++ 6 files changed, 509 insertions(+) create mode 100644 docs/pydoc/config/tools_api.yml create mode 100644 haystack/components/tools/__init__.py create mode 100644 haystack/components/tools/tool_invoker.py create mode 100644 releasenotes/notes/add-tool-invoker-3bc567b27aec2f32.yaml create mode 100644 test/components/tools/__init__.py create mode 100644 test/components/tools/test_tool_invoker.py diff --git a/docs/pydoc/config/tools_api.yml b/docs/pydoc/config/tools_api.yml new file mode 100644 index 0000000000..449f720f03 --- /dev/null +++ b/docs/pydoc/config/tools_api.yml @@ -0,0 +1,27 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../../../haystack/components/tools] + modules: ["tool_invoker"] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeCoreRenderer + excerpt: Components related to Tool Calling. + category_slug: haystack-api + title: Tools + slug: tools-api + order: 152 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: tools_api.md diff --git a/haystack/components/tools/__init__.py b/haystack/components/tools/__init__.py new file mode 100644 index 0000000000..487097a109 --- /dev/null +++ b/haystack/components/tools/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.components.tools.tool_invoker import ToolInvoker + +_all_ = ["ToolInvoker"] diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py new file mode 100644 index 0000000000..13d556e279 --- /dev/null +++ b/haystack/components/tools/tool_invoker.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import warnings +from typing import Any, Dict, List + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses.chat_message import ChatMessage, ToolCall +from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace + +logger = logging.getLogger(__name__) + +_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}." +_TOOL_NOT_FOUND = "Tool {tool_name} not found in the list of tools. Available tools are: {available_tools}." +_TOOL_RESULT_CONVERSION_FAILURE = ( + "Failed to convert tool result to string using '{conversion_function}'. Error: {error}." +) + + +class ToolNotFoundException(Exception): + """ + Exception raised when a tool is not found in the list of available tools. + """ + + pass + + +class StringConversionError(Exception): + """ + Exception raised when the conversion of a tool result to a string fails. + """ + + pass + + +@component +class ToolInvoker: + """ + Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects. + + At initialization, the ToolInvoker component is provided with a list of available tools. + At runtime, the component processes a list of ChatMessage object containing tool calls + and invokes the corresponding tools. + The results of the tool invocations are returned as a list of ChatMessage objects with tool role. + + Usage example: + ```python + from haystack.dataclasses import ChatMessage, ToolCall, Tool + from haystack.components.tools import ToolInvoker + + # Tool definition + def dummy_weather_function(city: str): + return f"The weather in {city} is 20 degrees." + + parameters = {"type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"]} + + tool = Tool(name="weather_tool", + description="A tool to get the weather", + function=dummy_weather_function, + parameters=parameters) + + # Usually, the ChatMessage with tool_calls is generated by a Language Model + # Here, we create it manually for demonstration purposes + tool_call = ToolCall( + tool_name="weather_tool", + arguments={"city": "Berlin"} + ) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + # ToolInvoker initialization and run + invoker = ToolInvoker(tools=[tool]) + result = invoker.run(messages=[message]) + + print(result) + ``` + + ``` + >> { + >> 'tool_messages': [ + >> ChatMessage( + >> _role=, + >> _content=[ + >> ToolCallResult( + >> result='"The weather in Berlin is 20 degrees."', + >> origin=ToolCall( + >> tool_name='weather_tool', + >> arguments={'city': 'Berlin'}, + >> id=None + >> ) + >> ) + >> ], + >> _meta={} + >> ) + >> ] + >> } + ``` + """ + + def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False): + """ + Initialize the ToolInvoker component. + + :param tools: + A list of tools that can be invoked. + :param raise_on_failure: + If True, the component will raise an exception in case of errors + (tool not found, tool invocation errors, tool result conversion errors). + If False, the component will return a ChatMessage object with `error=True` + and a description of the error in `result`. + :param convert_result_to_json_string: + If True, the tool invocation result will be converted to a string using `json.dumps`. + If False, the tool invocation result will be converted to a string using `str`. + + :raises ValueError: + If no tools are provided or if duplicate tool names are found. + """ + + msg = "The `ToolInvoker` component is experimental and its API may change in the future." + warnings.warn(msg) + + if not tools: + raise ValueError("ToolInvoker requires at least one tool to be provided.") + _check_duplicate_tool_names(tools) + + self.tools = tools + self._tools_with_names = dict(zip([tool.name for tool in tools], tools)) + self.raise_on_failure = raise_on_failure + self.convert_result_to_json_string = convert_result_to_json_string + + def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage: + """ + Prepares a ChatMessage with the result of a tool invocation. + + :param result: + The tool result. + :returns: + A ChatMessage object containing the tool result as a string. + + :raises + StringConversionError: If the conversion of the tool result to a string fails + and `raise_on_failure` is True. + """ + error = False + + if self.convert_result_to_json_string: + try: + # We disable ensure_ascii so special chars like emojis are not converted + tool_result_str = json.dumps(result, ensure_ascii=False) + except Exception as e: + if self.raise_on_failure: + raise StringConversionError("Failed to convert tool result to string using `json.dumps`") from e + tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="json.dumps") + error = True + return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) + + try: + tool_result_str = str(result) + except Exception as e: + if self.raise_on_failure: + raise StringConversionError("Failed to convert tool result to string using `str`") from e + tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function="str") + error = True + return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call) + + @component.output_types(tool_messages=List[ChatMessage]) + def run(self, messages: List[ChatMessage]) -> Dict[str, Any]: + """ + Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available. + + :param messages: + A list of ChatMessage objects. + :returns: + A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role. + Each ChatMessage objects wraps the result of a tool invocation. + + :raises ToolNotFoundException: + If the tool is not found in the list of available tools and `raise_on_failure` is True. + :raises ToolInvocationError: + If the tool invocation fails and `raise_on_failure` is True. + :raises StringConversionError: + If the conversion of the tool result to a string fails and `raise_on_failure` is True. + """ + tool_messages = [] + + for message in messages: + tool_calls = message.tool_calls + if not tool_calls: + continue + + for tool_call in tool_calls: + tool_name = tool_call.tool_name + tool_arguments = tool_call.arguments + + if not tool_name in self._tools_with_names: + msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=self._tools_with_names.keys()) + if self.raise_on_failure: + raise ToolNotFoundException(msg) + tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) + continue + + tool_to_invoke = self._tools_with_names[tool_name] + try: + tool_result = tool_to_invoke.invoke(**tool_arguments) + except ToolInvocationError as e: + if self.raise_on_failure: + raise e + msg = _TOOL_INVOCATION_FAILURE.format(error=e) + tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True)) + continue + + tool_message = self._prepare_tool_result_message(tool_result, tool_call) + tool_messages.append(tool_message) + + return {"tool_messages": tool_messages} + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + serialized_tools = [tool.to_dict() for tool in self.tools] + return default_to_dict( + self, + tools=serialized_tools, + raise_on_failure=self.raise_on_failure, + convert_result_to_json_string=self.convert_result_to_json_string, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + deserialize_tools_inplace(data["init_parameters"], key="tools") + return default_from_dict(cls, data) diff --git a/releasenotes/notes/add-tool-invoker-3bc567b27aec2f32.yaml b/releasenotes/notes/add-tool-invoker-3bc567b27aec2f32.yaml new file mode 100644 index 0000000000..6c1f52344f --- /dev/null +++ b/releasenotes/notes/add-tool-invoker-3bc567b27aec2f32.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Add a new experimental component `ToolInvoker`. + This component invokes tools based on tool calls prepared by Language Models and returns the results as a list of + ChatMessage objects with tool role. diff --git a/test/components/tools/__init__.py b/test/components/tools/__init__.py new file mode 100644 index 0000000000..c1764a6e03 --- /dev/null +++ b/test/components/tools/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/tools/test_tool_invoker.py b/test/components/tools/test_tool_invoker.py new file mode 100644 index 0000000000..34b1ca9fef --- /dev/null +++ b/test/components/tools/test_tool_invoker.py @@ -0,0 +1,220 @@ +import pytest +import datetime + +from haystack import Pipeline + +from haystack.dataclasses import ChatMessage, ToolCall, ToolCallResult, ChatRole +from haystack.dataclasses.tool import Tool, ToolInvocationError +from haystack.components.tools.tool_invoker import ToolInvoker, ToolNotFoundException, StringConversionError + + +def weather_function(location): + weather_info = { + "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, + "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, + "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, + } + return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) + + +weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]} + + +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters=weather_parameters, + function=weather_function, + ) + + +@pytest.fixture +def faulty_tool(): + def faulty_tool_func(location): + raise Exception("This tool always fails.") + + faulty_tool_parameters = { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + + return Tool( + name="faulty_tool", + description="A tool that always fails when invoked.", + parameters=faulty_tool_parameters, + function=faulty_tool_func, + ) + + +@pytest.fixture +def invoker(weather_tool): + return ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False) + + +@pytest.fixture +def faulty_invoker(faulty_tool): + return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False) + + +class TestToolInvoker: + def test_init(self, weather_tool): + invoker = ToolInvoker(tools=[weather_tool]) + + assert invoker.tools == [weather_tool] + assert invoker._tools_with_names == {"weather_tool": weather_tool} + assert invoker.raise_on_failure + assert not invoker.convert_result_to_json_string + + def test_init_fails_wo_tools(self): + with pytest.raises(ValueError): + ToolInvoker(tools=[]) + + def test_init_fails_with_duplicate_tool_names(self, weather_tool, faulty_tool): + with pytest.raises(ValueError): + ToolInvoker(tools=[weather_tool, weather_tool]) + + new_tool = faulty_tool + new_tool.name = "weather_tool" + with pytest.raises(ValueError): + ToolInvoker(tools=[weather_tool, new_tool]) + + def test_run(self, invoker): + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) + message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + result = invoker.run(messages=[message]) + assert "tool_messages" in result + assert len(result["tool_messages"]) == 1 + + tool_message = result["tool_messages"][0] + assert isinstance(tool_message, ChatMessage) + assert tool_message.is_from(ChatRole.TOOL) + + assert tool_message.tool_call_results + tool_call_result = tool_message.tool_call_result + + assert isinstance(tool_call_result, ToolCallResult) + assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}) + assert tool_call_result.origin == tool_call + assert not tool_call_result.error + + def test_run_no_messages(self, invoker): + result = invoker.run(messages=[]) + assert result == {"tool_messages": []} + + def test_run_no_tool_calls(self, invoker): + user_message = ChatMessage.from_user(text="Hello!") + assistant_message = ChatMessage.from_assistant(text="How can I help you?") + + result = invoker.run(messages=[user_message, assistant_message]) + assert result == {"tool_messages": []} + + def test_run_multiple_tool_calls(self, invoker): + tool_calls = [ + ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), + ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}), + ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}), + ] + message = ChatMessage.from_assistant(tool_calls=tool_calls) + + result = invoker.run(messages=[message]) + assert "tool_messages" in result + assert len(result["tool_messages"]) == 3 + + for i, tool_message in enumerate(result["tool_messages"]): + assert isinstance(tool_message, ChatMessage) + assert tool_message.is_from(ChatRole.TOOL) + + assert tool_message.tool_call_results + tool_call_result = tool_message.tool_call_result + + assert isinstance(tool_call_result, ToolCallResult) + assert not tool_call_result.error + assert tool_call_result.origin == tool_calls[i] + + def test_tool_not_found_error(self, invoker): + tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + with pytest.raises(ToolNotFoundException): + invoker.run(messages=[tool_call_message]) + + def test_tool_not_found_does_not_raise_exception(self, invoker): + invoker.raise_on_failure = False + + tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + result = invoker.run(messages=[tool_call_message]) + tool_message = result["tool_messages"][0] + + assert tool_message.tool_call_results[0].error + assert "not found" in tool_message.tool_call_results[0].result + + def test_tool_invocation_error(self, faulty_invoker): + tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + with pytest.raises(ToolInvocationError): + faulty_invoker.run(messages=[tool_call_message]) + + def test_tool_invocation_error_does_not_raise_exception(self, faulty_invoker): + faulty_invoker.raise_on_failure = False + + tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"}) + tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) + + result = faulty_invoker.run(messages=[tool_call_message]) + tool_message = result["tool_messages"][0] + assert tool_message.tool_call_results[0].error + assert "invocation failed" in tool_message.tool_call_results[0].result + + def test_string_conversion_error(self, invoker): + invoker.convert_result_to_json_string = True + + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) + + tool_result = datetime.datetime.now() + with pytest.raises(StringConversionError): + invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call) + + def test_string_conversion_error_does_not_raise_exception(self, invoker): + invoker.convert_result_to_json_string = True + invoker.raise_on_failure = False + + tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) + + tool_result = datetime.datetime.now() + tool_message = invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call) + + assert tool_message.tool_call_results[0].error + assert "Failed to convert" in tool_message.tool_call_results[0].result + + def test_to_dict(self, invoker, weather_tool): + data = invoker.to_dict() + assert data == { + "type": "haystack.components.tools.tool_invoker.ToolInvoker", + "init_parameters": { + "tools": [weather_tool.to_dict()], + "raise_on_failure": True, + "convert_result_to_json_string": False, + }, + } + + def test_from_dict(self, weather_tool): + data = { + "type": "haystack.components.tools.tool_invoker.ToolInvoker", + "init_parameters": { + "tools": [weather_tool.to_dict()], + "raise_on_failure": True, + "convert_result_to_json_string": False, + }, + } + invoker = ToolInvoker.from_dict(data) + assert invoker.tools == [weather_tool] + assert invoker._tools_with_names == {"weather_tool": weather_tool} + assert invoker.raise_on_failure + assert not invoker.convert_result_to_json_string