Skip to content

Commit

Permalink
chore: use class methods to create ChatMessage (#8581)
Browse files Browse the repository at this point in the history
* use class methods to build messages

* fix failing format
  • Loading branch information
anakin87 authored and Amnah199 committed Dec 3, 2024
1 parent c61600e commit 461e92f
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 94 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Across Haystack codebase, we have replaced the use of `ChatMessage` dataclass constructor with specific
class methods (`ChatMessage.from_user`, `ChatMessage.from_assistant`, etc.).
132 changes: 45 additions & 87 deletions test/components/builders/test_answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,77 +164,54 @@ def test_run_with_reference_pattern_set_at_runtime(self):

def test_run_with_chat_message_replies_without_pattern(self):
component = AnswerBuilder()
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString"
assert answers[0].meta == {

message_meta = {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
replies = [ChatMessage.from_assistant("Answer: AnswerString", meta=message_meta)]

output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString"
assert answers[0].meta == message_meta
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)

def test_run_with_chat_message_replies_with_pattern(self):
component = AnswerBuilder(pattern=r"Answer: (.*)")
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == {

message_meta = {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
replies = [ChatMessage.from_assistant("Answer: AnswerString", meta=message_meta)]

output = component.run(query="test query", replies=replies, meta=[{}])
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == message_meta
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)

def test_run_with_chat_message_replies_with_documents(self):
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
replies = [
ChatMessage(
content="Answer: AnswerString[2]",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
message_meta = {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
replies = [ChatMessage.from_assistant("Answer: AnswerString[2]", meta=message_meta)]

output = component.run(
query="test query",
replies=replies,
Expand All @@ -244,60 +221,40 @@ def test_run_with_chat_message_replies_with_documents(self):
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString[2]"
assert answers[0].meta == {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
assert answers[0].meta == message_meta
assert answers[0].query == "test query"
assert len(answers[0].documents) == 1
assert answers[0].documents[0].content == "test doc 2"

def test_run_with_chat_message_replies_with_pattern_set_at_runtime(self):
component = AnswerBuilder(pattern="unused pattern")
replies = [
ChatMessage(
content="Answer: AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
output = component.run(query="test query", replies=replies, meta=[{}], pattern=r"Answer: (.*)")
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == {
message_meta = {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
replies = [ChatMessage.from_assistant("Answer: AnswerString", meta=message_meta)]

output = component.run(query="test query", replies=replies, meta=[{}], pattern=r"Answer: (.*)")
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].meta == message_meta
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)

def test_run_with_chat_message_replies_with_meta_set_at_run_time(self):
component = AnswerBuilder()
replies = [
ChatMessage(
content="AnswerString",
role=ChatRole.ASSISTANT,
name=None,
meta={
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
},
)
]
message_meta = {
"model": "gpt-4o-mini",
"index": 0,
"finish_reason": "stop",
"usage": {"prompt_tokens": 32, "completion_tokens": 153, "total_tokens": 185},
}
replies = [ChatMessage.from_assistant("AnswerString", meta=message_meta)]

output = component.run(query="test query", replies=replies, meta=[{"test": "meta"}])
answers = output["answers"]
assert len(answers) == 1
Expand All @@ -315,8 +272,9 @@ def test_run_with_chat_message_replies_with_meta_set_at_run_time(self):

def test_run_with_chat_message_no_meta_with_meta_set_at_run_time(self):
component = AnswerBuilder()
replies = [ChatMessage(content="AnswerString", role=ChatRole.ASSISTANT, name=None, meta={})]
replies = [ChatMessage.from_assistant("AnswerString")]
output = component.run(query="test query", replies=replies, meta=[{"test": "meta"}])

answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
Expand Down
9 changes: 6 additions & 3 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,14 @@ def test_run_with_meta(self):
Test that the ChatPromptBuilder correctly handles meta data.
It should render the message and copy the meta data from the original message.
"""
m = ChatMessage(content="This is a {{ variable }}", role=ChatRole.USER, name=None, meta={"test": "test"})
m = ChatMessage.from_user("This is a {{ variable }}")
m.meta["meta_field"] = "meta_value"
builder = ChatPromptBuilder(template=[m])
res = builder.run(variable="test")
res_msg = ChatMessage(content="This is a test", role=ChatRole.USER, name=None, meta={"test": "test"})
assert res == {"prompt": [res_msg]}

expected_msg = ChatMessage.from_user("This is a test")
expected_msg.meta["meta_field"] = "meta_value"
assert res == {"prompt": [expected_msg]}

def test_run_with_invalid_template(self):
builder = ChatPromptBuilder()
Expand Down
8 changes: 4 additions & 4 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ def test_to_dict():


def test_from_dict():
assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage(
content="text", role=ChatRole("user"), name=None, meta={}
assert ChatMessage.from_dict(data={"content": "text", "role": "user", "name": None}) == ChatMessage.from_user(
"text"
)


def test_from_dict_with_meta():
assert ChatMessage.from_dict(
data={"content": "text", "role": "user", "name": None, "meta": {"something": "something"}}
) == ChatMessage(content="text", role=ChatRole("user"), name=None, meta={"something": "something"})
data={"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
) == ChatMessage.from_assistant("text", meta={"something": "something"})

0 comments on commit 461e92f

Please sign in to comment.