From 1c259f9e7a5e1f90970f0f49d5bd7ce2ab628ff5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Dec 2024 10:57:30 +0100 Subject: [PATCH] Handle our own dataclasses (e.g. Document) --- haystack_experimental/dataclasses/tool.py | 33 ++--- test/components/tools/test_tool_component.py | 147 +++++++++++++++++++ 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/haystack_experimental/dataclasses/tool.py b/haystack_experimental/dataclasses/tool.py index b4bff232..dfa2e389 100644 --- a/haystack_experimental/dataclasses/tool.py +++ b/haystack_experimental/dataclasses/tool.py @@ -228,30 +228,23 @@ def component_invoker(**kwargs): """ converted_kwargs = {} input_sockets = component.__haystack_input__._sockets_dict - for param_name, param_value in kwargs.items(): - socket = input_sockets[param_name] - param_type = socket.type - origin = get_origin(param_type) or param_type - - if origin is list: - target_type = get_args(param_type)[0] - values_to_convert = param_value + param_type = input_sockets[param_name].type + + # Check if the type (or list element type) has from_dict + target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type + if hasattr(target_type, "from_dict"): + if isinstance(param_value, list): + param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)] + elif isinstance(param_value, dict): + param_value = target_type.from_dict(param_value) else: - target_type = param_type - values_to_convert = [param_value] - - if isinstance(param_value, dict): - # TypeAdapter handles dict conversion for both dataclasses and Pydantic models - type_adapter = TypeAdapter(target_type) - converted = [ - type_adapter.validate_python(item) for item in values_to_convert if isinstance(item, dict) - ] - param_value = converted if origin is list else converted[0] + # Let TypeAdapter handle both single values and lists + type_adapter = TypeAdapter(param_type) + param_value = type_adapter.validate_python(param_value) converted_kwargs[param_name] = param_value - - logger.debug(f"Invoking component with kwargs: {converted_kwargs}") + logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}") return component.run(**converted_kwargs) # Return a new Tool instance with the component invoker as the function to be called diff --git a/test/components/tools/test_tool_component.py b/test/components/tools/test_tool_component.py index 7fc89df7..3e0b0685 100644 --- a/test/components/tools/test_tool_component.py +++ b/test/components/tools/test_tool_component.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import json import os import pytest from typing import Dict, List, Optional, Any @@ -9,6 +10,7 @@ from haystack import component from pydantic import BaseModel from haystack import Pipeline +from haystack.dataclasses import Document from haystack_experimental.dataclasses import ChatMessage, ChatRole from haystack_experimental.components.tools.tool_invoker import ToolInvoker from haystack_experimental.components.generators.chat import OpenAIChatGenerator @@ -125,6 +127,21 @@ def run(self, person: Person) -> Dict[str, str]: } +@component +class DocumentProcessor: + """A component that processes a list of Documents.""" + + @component.output_types(concatenated=str) + def run(self, documents: List[Document]) -> Dict[str, str]: + """ + Concatenates the content of multiple documents with newlines. + + :param documents: List of Documents whose content will be concatenated + :returns: Dictionary containing the concatenated document contents + """ + return {"concatenated": '\n'.join(doc.content for doc in documents)} + + ## Unit tests class TestToolComponent: def test_from_component_basic(self): @@ -311,6 +328,104 @@ def test_from_component_with_nested_dataclass(self): assert "info" in result assert result["info"] == "Diana lives at 123 Elm Street, Metropolis." + def test_from_component_with_document_list(self): + component = DocumentProcessor() + + tool = Tool.from_component( + component=component, + name="document_processor", + description="A tool that concatenates document contents" + ) + + assert tool.parameters == { + "type": "object", + "properties": { + "documents": { + "type": "array", + "description": "List of Documents whose content will be concatenated", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Field 'id' of 'Document'." + }, + "content": { + "type": "string", + "description": "Field 'content' of 'Document'." + }, + "dataframe": { + "type": "string", + "description": "Field 'dataframe' of 'Document'." + }, + "blob": { + "type": "object", + "description": "Field 'blob' of 'Document'.", + "properties": { + "data": { + "type": "string", + "description": "Field 'data' of 'ByteStream'." + }, + "meta": { + "type": "string", + "description": "Field 'meta' of 'ByteStream'." + }, + "mime_type": { + "type": "string", + "description": "Field 'mime_type' of 'ByteStream'." + } + }, + "required": ["data"] + }, + "meta": { + "type": "string", + "description": "Field 'meta' of 'Document'." + }, + "score": { + "type": "number", + "description": "Field 'score' of 'Document'." + }, + "embedding": { + "type": "array", + "description": "Field 'embedding' of 'Document'.", + "items": { + "type": "number" + } + }, + "sparse_embedding": { + "type": "object", + "description": "Field 'sparse_embedding' of 'Document'.", + "properties": { + "indices": { + "type": "array", + "description": "Field 'indices' of 'SparseEmbedding'.", + "items": { + "type": "integer" + } + }, + "values": { + "type": "array", + "description": "Field 'values' of 'SparseEmbedding'.", + "items": { + "type": "number" + } + } + }, + "required": ["indices", "values"] + } + } + } + } + }, + "required": ["documents"] + } + + # Test tool invocation + result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}]) + assert isinstance(result, dict) + assert "concatenated" in result + assert result["concatenated"] == "First document\nSecond document" + ## Integration tests class TestToolComponentInPipelineWithOpenAI: @@ -452,6 +567,38 @@ def test_person_processor_in_pipeline(self): assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result assert not tool_message.tool_call_result.error + @pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_document_processor_in_pipeline(self): + component = DocumentProcessor() + tool = Tool.from_component( + component=component, + name="document_processor", + description="A tool that concatenates the content of multiple documents" + ) + + pipeline = Pipeline() + pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool])) + pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True)) + pipeline.connect("llm.replies", "tool_invoker.messages") + + message = ChatMessage.from_user( + text="I have two documents. First one says 'Hello world' and second one says 'Goodbye world'. Can you concatenate them?" + ) + + result = pipeline.run({"llm": {"messages": [message]}}) + + tool_messages = result["tool_invoker"]["tool_messages"] + assert len(tool_messages) == 1 + + tool_message = tool_messages[0] + assert tool_message.is_from(ChatRole.TOOL) + result = json.loads(tool_message.tool_call_result.result) + assert "concatenated" in result + assert "Hello world" in result["concatenated"] + assert "Goodbye world" in result["concatenated"] + assert not tool_message.tool_call_result.error +