From ea3602643aa52c27f3bea7bf5bc90b97f568dcdc Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 17 Dec 2024 17:02:04 +0100 Subject: [PATCH] feat!: new `ChatMessage` (#8640) * draft * del HF token in tests * adaptations * progress * fix type * import sorting * more control on deserialization * release note * improvements * support name field * fix chatpromptbuilder test * Update chat_message.py --------- Co-authored-by: Daria Fokina --- .../builders/chat_prompt_builder.py | 6 +- .../generators/chat/hugging_face_api.py | 7 +- .../components/generators/openai_utils.py | 9 +- haystack/dataclasses/__init__.py | 5 +- haystack/dataclasses/chat_message.py | 336 +++++++++++++++--- .../new-chatmessage-7f47d5bdeb6ad6f5.yaml | 23 ++ .../builders/test_chat_prompt_builder.py | 22 +- .../generators/chat/test_hugging_face_api.py | 7 - .../generators/test_openai_utils.py | 7 - .../routers/test_conditional_router.py | 6 +- test/core/pipeline/features/test_run.py | 6 +- test/dataclasses/test_chat_message.py | 283 +++++++++++---- 12 files changed, 560 insertions(+), 157 deletions(-) create mode 100644 releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml diff --git a/haystack/components/builders/chat_prompt_builder.py b/haystack/components/builders/chat_prompt_builder.py index fd9969f5b7..33e2feda2d 100644 --- a/haystack/components/builders/chat_prompt_builder.py +++ b/haystack/components/builders/chat_prompt_builder.py @@ -9,7 +9,7 @@ from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent logger = logging.getLogger(__name__) @@ -197,10 +197,10 @@ def run( if message.text is None: raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") compiled_template = self._env.from_string(message.text) - rendered_content = compiled_template.render(template_variables_combined) + rendered_text = compiled_template.render(template_variables_combined) # deep copy the message to avoid modifying the original message rendered_message: ChatMessage = deepcopy(message) - rendered_message.content = rendered_content + rendered_message._content = [TextContent(text=rendered_text)] processed_messages.append(rendered_message) else: processed_messages.append(message) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index d4ecd53f10..8711a9175a 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -25,13 +25,8 @@ def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, str]: :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ - formatted_msg = {"role": message.role.value, "content": message.content} - if message.name: - formatted_msg["name"] = message.name - - return formatted_msg + return {"role": message.role.value, "content": message.text or ""} @component diff --git a/haystack/components/generators/openai_utils.py b/haystack/components/generators/openai_utils.py index 5b1838c386..ab6d5e7b1d 100644 --- a/haystack/components/generators/openai_utils.py +++ b/haystack/components/generators/openai_utils.py @@ -13,16 +13,11 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]: See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. - :returns: A dictionary with the following key: + :returns: A dictionary with the following keys: - `role` - `content` - - `name` (optional) """ if message.text is None: raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}") - openai_msg = {"role": message.role.value, "content": message.text} - if message.name: - openai_msg["name"] = message.name - - return openai_msg + return {"role": message.role.value, "content": message.text} diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 231ce80713..91e8f0408f 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -4,7 +4,7 @@ from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer from haystack.dataclasses.byte_stream import ByteStream -from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent, ToolCall, ToolCallResult from haystack.dataclasses.document import Document from haystack.dataclasses.sparse_embedding import SparseEmbedding from haystack.dataclasses.streaming_chunk import StreamingChunk @@ -17,6 +17,9 @@ "ByteStream", "ChatMessage", "ChatRole", + "ToolCall", + "ToolCallResult", + "TextContent", "StreamingChunk", "SparseEmbedding", ] diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index fb15ee6f5e..5aadb9f752 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -5,104 +5,318 @@ import warnings from dataclasses import asdict, dataclass, field from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Sequence, Union + +LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"} class ChatRole(str, Enum): - """Enumeration representing the roles within a chat.""" + """ + Enumeration representing the roles within a chat. + """ - ASSISTANT = "assistant" + #: The user role. A message from the user contains only text. USER = "user" + + #: The system role. A message from the system contains only text. SYSTEM = "system" + + #: The assistant role. A message from the assistant can contain text and Tool calls. It can also store metadata. + ASSISTANT = "assistant" + + #: The tool role. A message from a tool contains the result of a Tool invocation. + TOOL = "tool" + + #: The function role. Deprecated in favor of `TOOL`. FUNCTION = "function" + @staticmethod + def from_str(string: str) -> "ChatRole": + """ + Convert a string to a ChatRole enum. + """ + enum_map = {e.value: e for e in ChatRole} + role = enum_map.get(string) + if role is None: + msg = f"Unknown chat role '{string}'. Supported roles are: {list(enum_map.keys())}" + raise ValueError(msg) + return role + + +@dataclass +class ToolCall: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param tool_name: The name of the Tool to call. + :param arguments: The arguments to call the Tool with. + """ + + tool_name: str + arguments: Dict[str, Any] + id: Optional[str] = None # noqa: A003 + + +@dataclass +class ToolCallResult: + """ + Represents the result of a Tool invocation. + + :param result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + """ + + result: str + origin: ToolCall + error: bool + + +@dataclass +class TextContent: + """ + The textual content of a chat message. + + :param text: The text content of the message. + """ + + text: str + + +ChatMessageContentT = Union[TextContent, ToolCall, ToolCallResult] + @dataclass class ChatMessage: """ Represents a message in a LLM chat conversation. - :param content: The text content of the message. - :param role: The role of the entity sending the message. - :param name: The name of the function being called (only applicable for role FUNCTION). - :param meta: Additional metadata associated with the message. + Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a ChatMessage. """ - content: str - role: ChatRole - name: Optional[str] - meta: Dict[str, Any] = field(default_factory=dict, hash=False) + _role: ChatRole + _content: Sequence[ChatMessageContentT] + _name: Optional[str] = None + _meta: Dict[str, Any] = field(default_factory=dict, hash=False) - @property - def text(self) -> Optional[str]: + def __new__(cls, *args, **kwargs): """ - Returns the textual content of the message. + This method is reimplemented to make the changes to the `ChatMessage` dataclass more visible. """ - # Currently, this property mirrors the `content` attribute. This will change in 2.9.0. - # The current actual return type is str. We are using Optional[str] to be ready for 2.9.0, - # when None will be a valid value for `text`. - return object.__getattribute__(self, "content") + + general_msg = ( + "Use the `from_assistant`, `from_user`, `from_system`, and `from_tool` class methods to create a " + "ChatMessage. For more information about the new API and how to migrate, see the documentation:" + " https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + + if any(param in kwargs for param in LEGACY_INIT_PARAMETERS): + raise TypeError( + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. " + f"{general_msg}" + ) + + allowed_content_types = (TextContent, ToolCall, ToolCallResult) + if len(args) > 1 and not isinstance(args[1], allowed_content_types): + raise TypeError( + "The `_content` parameter of `ChatMessage` must be one of the following types: " + f"{', '.join(t.__name__ for t in allowed_content_types)}. " + f"{general_msg}" + ) + + return super(ChatMessage, cls).__new__(cls) + + def __post_init__(self): + if self._role == ChatRole.FUNCTION: + msg = "The `FUNCTION` role has been deprecated in favor of `TOOL` and will be removed in 2.10.0. " + warnings.warn(msg, DeprecationWarning) def __getattribute__(self, name): - # this method is reimplemented to warn about the deprecation of the `content` attribute + """ + This method is reimplemented to make the `content` attribute removal more visible. + """ + if name == "content": msg = ( - "The `content` attribute of `ChatMessage` will be removed in Haystack 2.9.0. " - "Use the `text` property to access the textual value." + "The `content` attribute of `ChatMessage` has been removed. " + "Use the `text` property to access the textual value. " + "For more information about the new API and how to migrate, see the documentation: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" ) - warnings.warn(msg, DeprecationWarning) + raise AttributeError(msg) return object.__getattribute__(self, name) - def is_from(self, role: ChatRole) -> bool: + def __len__(self): + return len(self._content) + + @property + def role(self) -> ChatRole: + """ + Returns the role of the entity sending the message. + """ + return self._role + + @property + def meta(self) -> Dict[str, Any]: + """ + Returns the metadata associated with the message. + """ + return self._meta + + @property + def name(self) -> Optional[str]: + """ + Returns the name associated with the message. + """ + return self._name + + @property + def texts(self) -> List[str]: + """ + Returns the list of all texts contained in the message. + """ + return [content.text for content in self._content if isinstance(content, TextContent)] + + @property + def text(self) -> Optional[str]: + """ + Returns the first text contained in the message. + """ + if texts := self.texts: + return texts[0] + return None + + @property + def tool_calls(self) -> List[ToolCall]: + """ + Returns the list of all Tool calls contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCall)] + + @property + def tool_call(self) -> Optional[ToolCall]: + """ + Returns the first Tool call contained in the message. + """ + if tool_calls := self.tool_calls: + return tool_calls[0] + return None + + @property + def tool_call_results(self) -> List[ToolCallResult]: + """ + Returns the list of all Tool call results contained in the message. + """ + return [content for content in self._content if isinstance(content, ToolCallResult)] + + @property + def tool_call_result(self) -> Optional[ToolCallResult]: + """ + Returns the first Tool call result contained in the message. + """ + if tool_call_results := self.tool_call_results: + return tool_call_results[0] + return None + + def is_from(self, role: Union[ChatRole, str]) -> bool: """ Check if the message is from a specific role. :param role: The role to check against. :returns: True if the message is from the specified role, False otherwise. """ - return self.role == role + if isinstance(role, str): + role = ChatRole.from_str(role) + return self._role == role @classmethod - def from_assistant(cls, content: str, meta: Optional[Dict[str, Any]] = None) -> "ChatMessage": + def from_user(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ - Create a message from the assistant. + Create a message from the user. - :param content: The text content of the message. + :param text: The text content of the message. :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.ASSISTANT, None, meta or {}) + return cls(_role=ChatRole.USER, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod - def from_user(cls, content: str) -> "ChatMessage": + def from_system(cls, text: str, meta: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> "ChatMessage": """ - Create a message from the user. + Create a message from the system. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.USER, None) + return cls(_role=ChatRole.SYSTEM, _content=[TextContent(text=text)], _meta=meta or {}, _name=name) @classmethod - def from_system(cls, content: str) -> "ChatMessage": + def from_assistant( + cls, + text: Optional[str] = None, + meta: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + tool_calls: Optional[List[ToolCall]] = None, + ) -> "ChatMessage": """ - Create a message from the system. + Create a message from the assistant. - :param content: The text content of the message. + :param text: The text content of the message. + :param meta: Additional metadata associated with the message. + :param tool_calls: The Tool calls to include in the message. + :param name: An optional name for the participant. This field is only supported by OpenAI. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.SYSTEM, None) + content: List[ChatMessageContentT] = [] + if text is not None: + content.append(TextContent(text=text)) + if tool_calls: + content.extend(tool_calls) + + return cls(_role=ChatRole.ASSISTANT, _content=content, _meta=meta or {}, _name=name) + + @classmethod + def from_tool( + cls, tool_result: str, origin: ToolCall, error: bool = False, meta: Optional[Dict[str, Any]] = None + ) -> "ChatMessage": + """ + Create a message from a Tool. + + :param tool_result: The result of the Tool invocation. + :param origin: The Tool call that produced this result. + :param error: Whether the Tool invocation resulted in an error. + :param meta: Additional metadata associated with the message. + :returns: A new ChatMessage instance. + """ + return cls( + _role=ChatRole.TOOL, + _content=[ToolCallResult(result=tool_result, origin=origin, error=error)], + _meta=meta or {}, + ) @classmethod def from_function(cls, content: str, name: str) -> "ChatMessage": """ - Create a message from a function call. + Create a message from a function call. Deprecated in favor of `from_tool`. :param content: The text content of the message. :param name: The name of the function being called. :returns: A new ChatMessage instance. """ - return cls(content, ChatRole.FUNCTION, name) + msg = ( + "The `from_function` method is deprecated and will be removed in version 2.10.0. " + "Its behavior has changed: it now attempts to convert legacy function messages to tool messages. " + "This conversion is not guaranteed to succeed in all scenarios. " + "Please migrate to `ChatMessage.from_tool` and carefully verify the results if you " + "continue to use this method." + ) + warnings.warn(msg) + + return cls.from_tool(content, ToolCall(id=None, tool_name=name, arguments={}), error=False) def to_dict(self) -> Dict[str, Any]: """ @@ -111,10 +325,23 @@ def to_dict(self) -> Dict[str, Any]: :returns: Serialized version of the object. """ - data = asdict(self) - data["role"] = self.role.value + serialized: Dict[str, Any] = {} + serialized["_role"] = self._role.value + serialized["_meta"] = self._meta + serialized["_name"] = self._name + content: List[Dict[str, Any]] = [] + for part in self._content: + if isinstance(part, TextContent): + content.append({"text": part.text}) + elif isinstance(part, ToolCall): + content.append({"tool_call": asdict(part)}) + elif isinstance(part, ToolCallResult): + content.append({"tool_call_result": asdict(part)}) + else: + raise TypeError(f"Unsupported type in ChatMessage content: `{type(part).__name__}` for `{part}`.") - return data + serialized["_content"] = content + return serialized @classmethod def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": @@ -126,6 +353,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage": :returns: The created object. """ - data["role"] = ChatRole(data["role"]) + if any(param in data for param in LEGACY_INIT_PARAMETERS): + raise TypeError( + "The `role`, `content`, `meta`, and `name` init parameters of `ChatMessage` have been removed. " + "For more information about the new API and how to migrate, see the documentation: " + "https://docs.haystack.deepset.ai/docs/data-classes#chatmessage" + ) + + data["_role"] = ChatRole(data["_role"]) + + content: List[ChatMessageContentT] = [] + + for part in data["_content"]: + if "text" in part: + content.append(TextContent(text=part["text"])) + elif "tool_call" in part: + content.append(ToolCall(**part["tool_call"])) + elif "tool_call_result" in part: + result = part["tool_call_result"]["result"] + origin = ToolCall(**part["tool_call_result"]["origin"]) + error = part["tool_call_result"]["error"] + tcr = ToolCallResult(result=result, origin=origin, error=error) + content.append(tcr) + else: + raise ValueError(f"Unsupported content in serialized ChatMessage: `{part}`") + + data["_content"] = content return cls(**data) diff --git a/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml new file mode 100644 index 0000000000..b9e590e590 --- /dev/null +++ b/releasenotes/notes/new-chatmessage-7f47d5bdeb6ad6f5.yaml @@ -0,0 +1,23 @@ +--- +highlights: > + We are introducing a refactored ChatMessage dataclass. It is more flexible, future-proof, and compatible with + different types of content: text, tool calls, tool calls results. + For information about the new API and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +upgrade: + - | + The refactoring of the ChatMessage dataclass includes some breaking changes, involving ChatMessage creation and + accessing attributes. If you have a Pipeline containing a ChatPromptBuilder, serialized using Haystack<2.9.0, + deserialization may break. + For detailed information about the changes and how to migrate, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage +features: + - | + Changed the ChatMessage dataclass to support different types of content, including tool calls, and tool call + results. +deprecations: + - | + The function role and ChatMessage.from_function class method have been deprecated and will be removed in + Haystack 2.10.0. ChatMessage.from_function also attempts to produce a valid tool message. + For more information, see the documentation: + https://docs.haystack.deepset.ai/docs/data-classes#chatmessage diff --git a/test/components/builders/test_chat_prompt_builder.py b/test/components/builders/test_chat_prompt_builder.py index 5e1ae6132e..a8fb8bc5b8 100644 --- a/test/components/builders/test_chat_prompt_builder.py +++ b/test/components/builders/test_chat_prompt_builder.py @@ -13,8 +13,8 @@ class TestChatPromptBuilder: def test_init(self): builder = ChatPromptBuilder( template=[ - ChatMessage.from_user(content="This is a {{ variable }}"), - ChatMessage.from_system(content="This is a {{ variable2 }}"), + ChatMessage.from_user("This is a {{ variable }}"), + ChatMessage.from_system("This is a {{ variable2 }}"), ] ) assert builder.required_variables == [] @@ -531,8 +531,13 @@ def test_to_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], @@ -545,8 +550,13 @@ def test_from_dict(self): "type": "haystack.components.builders.chat_prompt_builder.ChatPromptBuilder", "init_parameters": { "template": [ - {"content": "text and {var}", "role": "user", "name": None, "meta": {}}, - {"content": "content {required_var}", "role": "assistant", "name": None, "meta": {}}, + {"_content": [{"text": "text and {var}"}], "_role": "user", "_meta": {}, "_name": None}, + { + "_content": [{"text": "content {required_var}"}], + "_role": "assistant", + "_meta": {}, + "_name": None, + }, ], "variables": ["var", "required_var"], "required_variables": ["required_var"], diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index 3d7fd617c0..e60ec863ab 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -68,13 +68,6 @@ def test_convert_message_to_hfapi_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_hfapi_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } - class TestHuggingFaceAPIGenerator: def test_init_invalid_api_type(self): diff --git a/test/components/generators/test_openai_utils.py b/test/components/generators/test_openai_utils.py index 226b32f811..916a3e3d70 100644 --- a/test/components/generators/test_openai_utils.py +++ b/test/components/generators/test_openai_utils.py @@ -14,10 +14,3 @@ def test_convert_message_to_openai_format(): message = ChatMessage.from_user("I have a question") assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} - - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } diff --git a/test/components/routers/test_conditional_router.py b/test/components/routers/test_conditional_router.py index e0f3552319..66d941b645 100644 --- a/test/components/routers/test_conditional_router.py +++ b/test/components/routers/test_conditional_router.py @@ -349,7 +349,7 @@ def test_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert res == {"message": message} @@ -370,7 +370,7 @@ def test_validate_output_type_without_unsafe(self): ] router = ConditionalRouter(routes, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") with pytest.raises(ValueError, match="Route 'message' type doesn't match expected type"): router.run(streams=streams, message=message) @@ -391,7 +391,7 @@ def test_validate_output_type_with_unsafe(self): ] router = ConditionalRouter(routes, unsafe=True, validate_output_type=True) streams = [1] - message = ChatMessage.from_user(content="This is a message") + message = ChatMessage.from_user("This is a message") res = router.run(streams=streams, message=message) assert isinstance(res["message"], ChatMessage) diff --git a/test/core/pipeline/features/test_run.py b/test/core/pipeline/features/test_run.py index d7001a0187..8f07dfec99 100644 --- a/test/core/pipeline/features/test_run.py +++ b/test/core/pipeline/features/test_run.py @@ -1657,7 +1657,7 @@ def run(self, query: str): class ToolExtractor: @component.output_types(output=List[str]) def run(self, messages: List[ChatMessage]): - prompt: str = messages[-1].content + prompt: str = messages[-1].text lines = prompt.strip().split("\n") for line in reversed(lines): pattern = r"Action:\s*(\w+)\[(.*?)\]" @@ -1678,14 +1678,14 @@ def __init__(self, suffix: str = ""): @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage], current_prompt: List[ChatMessage]): - content = current_prompt[-1].content + replies[-1].content + self._suffix + content = current_prompt[-1].text + replies[-1].text + self._suffix return {"output": [ChatMessage.from_user(content)]} @component class SearchOutputAdapter: @component.output_types(output=List[ChatMessage]) def run(self, replies: List[ChatMessage]): - content = f"Observation: {replies[-1].content}\n" + content = f"Observation: {replies[-1].text}\n" return {"output": [ChatMessage.from_assistant(content)]} pipeline.add_component("prompt_concatenator_after_action", PromptConcatenator()) diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 30ad51630e..832617e712 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -4,64 +4,240 @@ import pytest from transformers import AutoTokenizer -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.chat_message import ChatMessage, ChatRole, ToolCall, ToolCallResult, TextContent from haystack.components.generators.openai_utils import _convert_message_to_openai_format +def test_tool_call_init(): + tc = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tc.id == "123" + assert tc.tool_name == "mytool" + assert tc.arguments == {"a": 1} + + +def test_tool_call_result_init(): + tcr = ToolCallResult(result="result", origin=ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), error=True) + assert tcr.result == "result" + assert tcr.origin == ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + assert tcr.error + + +def test_text_content_init(): + tc = TextContent(text="Hello") + assert tc.text == "Hello" + + def test_from_assistant_with_valid_content(): - content = "Hello, how can I assist you?" - message = ChatMessage.from_assistant(content) - assert message.content == content - assert message.text == content + text = "Hello, how can I assist you?" + message = ChatMessage.from_assistant(text) + assert message.role == ChatRole.ASSISTANT + assert message._content == [TextContent(text)] + assert message.name is None + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_assistant_with_tool_calls(): + tool_calls = [ + ToolCall(id="123", tool_name="mytool", arguments={"a": 1}), + ToolCall(id="456", tool_name="mytool2", arguments={"b": 2}), + ] + + message = ChatMessage.from_assistant(tool_calls=tool_calls) + + assert message.role == ChatRole.ASSISTANT + assert message._content == tool_calls + + assert message.tool_calls == tool_calls + assert message.tool_call == tool_calls[0] + + assert not message.texts + assert not message.text + assert not message.tool_call_results + assert not message.tool_call_result def test_from_user_with_valid_content(): - content = "I have a question." - message = ChatMessage.from_user(content) - assert message.content == content - assert message.text == content + text = "I have a question." + message = ChatMessage.from_user(text=text) + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] + assert message.name is None + + assert message.text == text + assert message.texts == [text] + + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_user_with_name(): + text = "I have a question." + message = ChatMessage.from_user(text=text, name="John") + + assert message.name == "John" + assert message.role == ChatRole.USER + assert message._content == [TextContent(text)] def test_from_system_with_valid_content(): - content = "System message." - message = ChatMessage.from_system(content) - assert message.content == content - assert message.text == content + text = "I have a question." + message = ChatMessage.from_system(text=text) + assert message.role == ChatRole.SYSTEM + assert message._content == [TextContent(text)] + assert message.text == text + assert message.texts == [text] -def test_with_empty_content(): - message = ChatMessage.from_user("") - assert message.content == "" - assert message.text == "" - assert message.role == ChatRole.USER + assert not message.tool_calls + assert not message.tool_call + assert not message.tool_call_results + assert not message.tool_call_result + + +def test_from_tool_with_valid_content(): + tool_result = "Tool result" + origin = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + message = ChatMessage.from_tool(tool_result, origin, error=False) + + tcr = ToolCallResult(result=tool_result, origin=origin, error=False) + + assert message._content == [tcr] + assert message.role == ChatRole.TOOL + + assert message.tool_call_result == tcr + assert message.tool_call_results == [tcr] + + assert not message.tool_calls + assert not message.tool_call + assert not message.texts + assert not message.text + + +def test_multiple_text_segments(): + texts = [TextContent(text="Hello"), TextContent(text="World")] + message = ChatMessage(_role=ChatRole.USER, _content=texts) + + assert message.texts == ["Hello", "World"] + assert len(message) == 2 + + +def test_mixed_content(): + content = [TextContent(text="Hello"), ToolCall(id="123", tool_name="mytool", arguments={"a": 1})] + + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=content) + assert len(message) == 2 + assert message.texts == ["Hello"] + assert message.text == "Hello" -def test_from_function_with_empty_name(): - content = "Function call" - message = ChatMessage.from_function(content, "") - assert message.content == content - assert message.text == content - assert message.name == "" - assert message.role == ChatRole.FUNCTION + assert message.tool_calls == [content[1]] + assert message.tool_call == content[1] -def test_to_openai_format(): - message = ChatMessage.from_system("You are good assistant") - assert _convert_message_to_openai_format(message) == {"role": "system", "content": "You are good assistant"} +def test_from_function(): + # check warning is raised + with pytest.warns(): + message = ChatMessage.from_function("Result of function invocation", "my_function") - message = ChatMessage.from_user("I have a question") - assert _convert_message_to_openai_format(message) == {"role": "user", "content": "I have a question"} + assert message.role == ChatRole.TOOL + assert message.tool_call_result == ToolCallResult( + result="Result of function invocation", + origin=ToolCall(id=None, tool_name="my_function", arguments={}), + error=False, + ) + + +def test_serde(): + # the following message is created just for testing purposes and does not make sense in a real use case + + role = ChatRole.ASSISTANT + + text_content = TextContent(text="Hello") + tool_call = ToolCall(id="123", tool_name="mytool", arguments={"a": 1}) + tool_call_result = ToolCallResult(result="result", origin=tool_call, error=False) + meta = {"some": "info"} - message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_openai_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", + message = ChatMessage(_role=role, _content=[text_content, tool_call, tool_call_result], _meta=meta) + + serialized_message = message.to_dict() + assert serialized_message == { + "_content": [ + {"text": "Hello"}, + {"tool_call": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}}, + { + "tool_call_result": { + "result": "result", + "error": False, + "origin": {"id": "123", "tool_name": "mytool", "arguments": {"a": 1}}, + } + }, + ], + "_role": "assistant", + "_name": None, + "_meta": {"some": "info"}, } + deserialized_message = ChatMessage.from_dict(serialized_message) + assert deserialized_message == message + + +def test_to_dict_with_invalid_content_type(): + text_content = TextContent(text="Hello") + invalid_content = "invalid" + + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[text_content, invalid_content]) + + with pytest.raises(TypeError): + message.to_dict() + + +def test_from_dict_with_invalid_content_type(): + data = {"_role": "assistant", "_content": [{"text": "Hello"}, "invalid"]} + with pytest.raises(ValueError): + ChatMessage.from_dict(data) + + data = {"_role": "assistant", "_content": [{"text": "Hello"}, {"invalid": "invalid"}]} + with pytest.raises(ValueError): + ChatMessage.from_dict(data) + + +def test_from_dict_with_legacy_init_parameters(): + with pytest.raises(TypeError): + ChatMessage.from_dict({"role": "user", "content": "This is a message"}) + + +def test_chat_message_content_attribute_removed(): + message = ChatMessage.from_user(text="This is a message") + with pytest.raises(AttributeError): + message.content + + +def test_chat_message_init_parameters_removed(): + with pytest.raises(TypeError): + ChatMessage(role="irrelevant", content="This is a message") + + +def test_chat_message_init_content_parameter_type(): + with pytest.raises(TypeError): + ChatMessage(ChatRole.USER, "This is a message") + + +def test_chat_message_function_role_deprecated(): + with pytest.warns(DeprecationWarning): + ChatMessage(ChatRole.FUNCTION, TextContent("This is a message")) + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): @@ -93,40 +269,3 @@ def test_apply_custom_chat_templating_on_chat_message(): formatted_messages, chat_template=anthropic_template, tokenize=False ) assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:" - - -def test_to_dict(): - content = "content" - role = "user" - meta = {"some": "some"} - - message = ChatMessage.from_user(content) - message.meta.update(meta) - - assert message.text == content - assert message.to_dict() == {"content": content, "role": role, "name": None, "meta": meta} - - -def test_from_dict(): - assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user( - "text" - ) - - -def test_from_dict_with_meta(): - data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}} - assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"}) - - -def test_content_deprecation_warning(recwarn): - message = ChatMessage.from_user("my message") - - # accessing the content attribute triggers the deprecation warning - _ = message.content - assert len(recwarn) == 1 - wrn = recwarn.pop(DeprecationWarning) - assert "`content` attribute" in wrn.message.args[0] - - # accessing the text property does not trigger a warning - assert message.text == "my message" - assert len(recwarn) == 0