Skip to content

Commit

Permalink
Handle our own dataclasses (e.g. Document)
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Dec 19, 2024
1 parent 6341268 commit 1c259f9
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 20 deletions.
33 changes: 13 additions & 20 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,30 +228,23 @@ def component_invoker(**kwargs):
"""
converted_kwargs = {}
input_sockets = component.__haystack_input__._sockets_dict

for param_name, param_value in kwargs.items():
socket = input_sockets[param_name]
param_type = socket.type
origin = get_origin(param_type) or param_type

if origin is list:
target_type = get_args(param_type)[0]
values_to_convert = param_value
param_type = input_sockets[param_name].type

# Check if the type (or list element type) has from_dict
target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
if hasattr(target_type, "from_dict"):
if isinstance(param_value, list):
param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
elif isinstance(param_value, dict):
param_value = target_type.from_dict(param_value)
else:
target_type = param_type
values_to_convert = [param_value]

if isinstance(param_value, dict):
# TypeAdapter handles dict conversion for both dataclasses and Pydantic models
type_adapter = TypeAdapter(target_type)
converted = [
type_adapter.validate_python(item) for item in values_to_convert if isinstance(item, dict)
]
param_value = converted if origin is list else converted[0]
# Let TypeAdapter handle both single values and lists
type_adapter = TypeAdapter(param_type)
param_value = type_adapter.validate_python(param_value)

converted_kwargs[param_name] = param_value

logger.debug(f"Invoking component with kwargs: {converted_kwargs}")
logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
return component.run(**converted_kwargs)

# Return a new Tool instance with the component invoker as the function to be called
Expand Down
147 changes: 147 additions & 0 deletions test/components/tools/test_tool_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
import os
import pytest
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from haystack import component
from pydantic import BaseModel
from haystack import Pipeline
from haystack.dataclasses import Document
from haystack_experimental.dataclasses import ChatMessage, ChatRole
from haystack_experimental.components.tools.tool_invoker import ToolInvoker
from haystack_experimental.components.generators.chat import OpenAIChatGenerator
Expand Down Expand Up @@ -125,6 +127,21 @@ def run(self, person: Person) -> Dict[str, str]:
}


@component
class DocumentProcessor:
    """A component that processes a list of Documents."""

    @component.output_types(concatenated=str)
    def run(self, documents: List[Document]) -> Dict[str, str]:
        """
        Concatenates the content of multiple documents with newlines.
        :param documents: List of Documents whose content will be concatenated
        :returns: Dictionary containing the concatenated document contents
        """
        # Gather each document's text first, then glue the pieces together
        # with newline separators in a single join.
        contents = [doc.content for doc in documents]
        return {"concatenated": "\n".join(contents)}


## Unit tests
class TestToolComponent:
def test_from_component_basic(self):
Expand Down Expand Up @@ -311,6 +328,104 @@ def test_from_component_with_nested_dataclass(self):
assert "info" in result
assert result["info"] == "Diana lives at 123 Elm Street, Metropolis."

def test_from_component_with_document_list(self):
    """Tool.from_component maps a List[Document] input to a JSON-schema array of Document objects and invokes correctly."""
    # Named `processor` (not `component`) to avoid shadowing the
    # `haystack.component` decorator imported at module level.
    processor = DocumentProcessor()

    tool = Tool.from_component(
        component=processor,
        name="document_processor",
        description="A tool that concatenates document contents"
    )

    assert tool.parameters == {
        "type": "object",
        "properties": {
            "documents": {
                "type": "array",
                "description": "List of Documents whose content will be concatenated",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {
                            "type": "string",
                            "description": "Field 'id' of 'Document'."
                        },
                        "content": {
                            "type": "string",
                            "description": "Field 'content' of 'Document'."
                        },
                        "dataframe": {
                            "type": "string",
                            "description": "Field 'dataframe' of 'Document'."
                        },
                        "blob": {
                            "type": "object",
                            "description": "Field 'blob' of 'Document'.",
                            "properties": {
                                "data": {
                                    "type": "string",
                                    "description": "Field 'data' of 'ByteStream'."
                                },
                                "meta": {
                                    "type": "string",
                                    "description": "Field 'meta' of 'ByteStream'."
                                },
                                "mime_type": {
                                    "type": "string",
                                    "description": "Field 'mime_type' of 'ByteStream'."
                                }
                            },
                            "required": ["data"]
                        },
                        "meta": {
                            "type": "string",
                            "description": "Field 'meta' of 'Document'."
                        },
                        "score": {
                            "type": "number",
                            "description": "Field 'score' of 'Document'."
                        },
                        "embedding": {
                            "type": "array",
                            "description": "Field 'embedding' of 'Document'.",
                            "items": {
                                "type": "number"
                            }
                        },
                        "sparse_embedding": {
                            "type": "object",
                            "description": "Field 'sparse_embedding' of 'Document'.",
                            "properties": {
                                "indices": {
                                    "type": "array",
                                    "description": "Field 'indices' of 'SparseEmbedding'.",
                                    "items": {
                                        "type": "integer"
                                    }
                                },
                                "values": {
                                    "type": "array",
                                    "description": "Field 'values' of 'SparseEmbedding'.",
                                    "items": {
                                        "type": "number"
                                    }
                                }
                            },
                            "required": ["indices", "values"]
                        }
                    }
                }
            }
        },
        "required": ["documents"]
    }

    # Test tool invocation: plain dicts must be converted to Documents
    # before the wrapped component runs.
    result = tool.invoke(documents=[{"content": "First document"}, {"content": "Second document"}])
    assert isinstance(result, dict)
    assert "concatenated" in result
    assert result["concatenated"] == "First document\nSecond document"


## Integration tests
class TestToolComponentInPipelineWithOpenAI:
Expand Down Expand Up @@ -452,6 +567,38 @@ def test_person_processor_in_pipeline(self):
assert "Diana" in tool_message.tool_call_result.result and "Metropolis" in tool_message.tool_call_result.result
assert not tool_message.tool_call_result.error

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_document_processor_in_pipeline(self):
    """End-to-end: the LLM selects the document_processor tool and the invoker returns its JSON-string result."""
    # Named `processor` (not `component`) to avoid shadowing the
    # `haystack.component` decorator imported at module level.
    processor = DocumentProcessor()
    tool = Tool.from_component(
        component=processor,
        name="document_processor",
        description="A tool that concatenates the content of multiple documents"
    )

    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool], convert_result_to_json_string=True))
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user(
        text="I have two documents. First one says 'Hello world' and second one says 'Goodbye world'. Can you concatenate them?"
    )

    result = pipeline.run({"llm": {"messages": [message]}})

    tool_messages = result["tool_invoker"]["tool_messages"]
    assert len(tool_messages) == 1

    tool_message = tool_messages[0]
    assert tool_message.is_from(ChatRole.TOOL)
    # Parse into a distinct name instead of rebinding `result`
    # (which still holds the pipeline output above).
    tool_result = json.loads(tool_message.tool_call_result.result)
    assert "concatenated" in tool_result
    assert "Hello world" in tool_result["concatenated"]
    assert "Goodbye world" in tool_result["concatenated"]
    assert not tool_message.tool_call_result.error




Expand Down

0 comments on commit 1c259f9

Please sign in to comment.