-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add ChatMessage class to Haystack 2.0 (#6144)
* Add ChatMessage and ChatRole
- Loading branch information
Showing
4 changed files
with
134 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from haystack.preview.dataclasses.document import Document | ||
from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer | ||
from haystack.preview.dataclasses.byte_stream import ByteStream | ||
from haystack.preview.dataclasses.chat_message import ChatMessage | ||
from haystack.preview.dataclasses.chat_message import ChatRole | ||
|
||
__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream"] | ||
__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream", "ChatMessage", "ChatRole"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from dataclasses import dataclass, field | ||
from enum import Enum | ||
from typing import Dict, Any, Optional | ||
|
||
|
||
class ChatRole(str, Enum): | ||
"""Enumeration representing the roles within a chat.""" | ||
|
||
ASSISTANT = "assistant" | ||
USER = "user" | ||
SYSTEM = "system" | ||
FUNCTION = "function" | ||
|
||
|
||
@dataclass | ||
class ChatMessage: | ||
""" | ||
Represents a message in a LLM chat conversation. | ||
:param content: The text content of the message. | ||
:param role: The role of the entity sending the message. | ||
:param name: The name of the function being called (only applicable for role FUNCTION). | ||
:param metadata: Additional metadata associated with the message. | ||
""" | ||
|
||
content: str | ||
role: ChatRole | ||
name: Optional[str] | ||
metadata: Dict[str, Any] = field(default_factory=dict, hash=False) | ||
|
||
def is_from(self, role: ChatRole) -> bool: | ||
""" | ||
Check if the message is from a specific role. | ||
:param role: The role to check against. | ||
:return: True if the message is from the specified role, False otherwise. | ||
""" | ||
return self.role == role | ||
|
||
@classmethod | ||
def from_assistant(cls, content: str) -> "ChatMessage": | ||
""" | ||
Create a message from the assistant. | ||
:param content: The text content of the message. | ||
:return: A new ChatMessage instance. | ||
""" | ||
return cls(content, ChatRole.ASSISTANT, None) | ||
|
||
@classmethod | ||
def from_user(cls, content: str) -> "ChatMessage": | ||
""" | ||
Create a message from the user. | ||
:param content: The text content of the message. | ||
:return: A new ChatMessage instance. | ||
""" | ||
return cls(content, ChatRole.USER, None) | ||
|
||
@classmethod | ||
def from_system(cls, content: str) -> "ChatMessage": | ||
""" | ||
Create a message from the system. | ||
:param content: The text content of the message. | ||
:return: A new ChatMessage instance. | ||
""" | ||
return cls(content, ChatRole.SYSTEM, None) | ||
|
||
@classmethod | ||
def from_function(cls, content: str, name: str) -> "ChatMessage": | ||
""" | ||
Create a message from a function call. | ||
:param content: The text content of the message. | ||
:param name: The name of the function being called. | ||
:return: A new ChatMessage instance. | ||
""" | ||
return cls(content, ChatRole.FUNCTION, name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
preview: | ||
- | | ||
Introduce ChatMessage data class to facilitate structured handling and processing of message content | ||
within LLM chat interactions. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pytest | ||
|
||
from haystack.preview.dataclasses import ChatMessage, ChatRole | ||
|
||
|
||
@pytest.mark.unit | ||
def test_from_assistant_with_valid_content(): | ||
content = "Hello, how can I assist you?" | ||
message = ChatMessage.from_assistant(content) | ||
assert message.content == content | ||
assert message.role == ChatRole.ASSISTANT | ||
|
||
|
||
@pytest.mark.unit | ||
def test_from_user_with_valid_content(): | ||
content = "I have a question." | ||
message = ChatMessage.from_user(content) | ||
assert message.content == content | ||
assert message.role == ChatRole.USER | ||
|
||
|
||
@pytest.mark.unit | ||
def test_from_system_with_valid_content(): | ||
content = "System message." | ||
message = ChatMessage.from_system(content) | ||
assert message.content == content | ||
assert message.role == ChatRole.SYSTEM | ||
|
||
|
||
@pytest.mark.unit | ||
def test_with_empty_content(): | ||
message = ChatMessage("", ChatRole.USER, None) | ||
assert message.content == "" | ||
|
||
|
||
@pytest.mark.unit | ||
def test_with_invalid_role(): | ||
with pytest.raises(TypeError): | ||
ChatMessage("Invalid role", "invalid_role") | ||
|
||
|
||
@pytest.mark.unit | ||
def test_from_function_with_empty_name(): | ||
content = "Function call" | ||
message = ChatMessage.from_function(content, "") | ||
assert message.content == content | ||
assert message.name == "" |