Skip to content

Commit

Permalink
feat: Add ChatMessage class to Haystack 2.0 (#6144)
Browse files Browse the repository at this point in the history
* Add ChatMessage and ChatRole
  • Loading branch information
vblagoje authored Oct 23, 2023
1 parent 9d8979a commit dcc7e63
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 1 deletion.
4 changes: 3 additions & 1 deletion haystack/preview/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from haystack.preview.dataclasses.document import Document
from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer
from haystack.preview.dataclasses.byte_stream import ByteStream
from haystack.preview.dataclasses.chat_message import ChatMessage
from haystack.preview.dataclasses.chat_message import ChatRole

__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream"]
__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream", "ChatMessage", "ChatRole"]
79 changes: 79 additions & 0 deletions haystack/preview/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, Optional


class ChatRole(str, Enum):
"""Enumeration representing the roles within a chat."""

ASSISTANT = "assistant"
USER = "user"
SYSTEM = "system"
FUNCTION = "function"


@dataclass
class ChatMessage:
"""
Represents a message in a LLM chat conversation.
:param content: The text content of the message.
:param role: The role of the entity sending the message.
:param name: The name of the function being called (only applicable for role FUNCTION).
:param metadata: Additional metadata associated with the message.
"""

content: str
role: ChatRole
name: Optional[str]
metadata: Dict[str, Any] = field(default_factory=dict, hash=False)

def is_from(self, role: ChatRole) -> bool:
"""
Check if the message is from a specific role.
:param role: The role to check against.
:return: True if the message is from the specified role, False otherwise.
"""
return self.role == role

@classmethod
def from_assistant(cls, content: str) -> "ChatMessage":
"""
Create a message from the assistant.
:param content: The text content of the message.
:return: A new ChatMessage instance.
"""
return cls(content, ChatRole.ASSISTANT, None)

@classmethod
def from_user(cls, content: str) -> "ChatMessage":
"""
Create a message from the user.
:param content: The text content of the message.
:return: A new ChatMessage instance.
"""
return cls(content, ChatRole.USER, None)

@classmethod
def from_system(cls, content: str) -> "ChatMessage":
"""
Create a message from the system.
:param content: The text content of the message.
:return: A new ChatMessage instance.
"""
return cls(content, ChatRole.SYSTEM, None)

@classmethod
def from_function(cls, content: str, name: str) -> "ChatMessage":
"""
Create a message from a function call.
:param content: The text content of the message.
:param name: The name of the function being called.
:return: A new ChatMessage instance.
"""
return cls(content, ChatRole.FUNCTION, name)
5 changes: 5 additions & 0 deletions releasenotes/notes/add-chat-message-c456e4603529ae85.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
preview:
- |
Introduce ChatMessage data class to facilitate structured handling and processing of message content
within LLM chat interactions.
47 changes: 47 additions & 0 deletions test/preview/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from haystack.preview.dataclasses import ChatMessage, ChatRole


@pytest.mark.unit
def test_from_assistant_with_valid_content():
content = "Hello, how can I assist you?"
message = ChatMessage.from_assistant(content)
assert message.content == content
assert message.role == ChatRole.ASSISTANT


@pytest.mark.unit
def test_from_user_with_valid_content():
content = "I have a question."
message = ChatMessage.from_user(content)
assert message.content == content
assert message.role == ChatRole.USER


@pytest.mark.unit
def test_from_system_with_valid_content():
content = "System message."
message = ChatMessage.from_system(content)
assert message.content == content
assert message.role == ChatRole.SYSTEM


@pytest.mark.unit
def test_with_empty_content():
message = ChatMessage("", ChatRole.USER, None)
assert message.content == ""


@pytest.mark.unit
def test_with_invalid_role():
with pytest.raises(TypeError):
ChatMessage("Invalid role", "invalid_role")


@pytest.mark.unit
def test_from_function_with_empty_name():
content = "Function call"
message = ChatMessage.from_function(content, "")
assert message.content == content
assert message.name == ""

0 comments on commit dcc7e63

Please sign in to comment.