Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update components to access ChatMessage.text instead of content #8589

Merged
merged 6 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion haystack/components/builders/answer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,12 @@ def run( # pylint: disable=too-many-positional-arguments
all_answers = []
for reply, given_metadata in zip(replies, meta):
# Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is
extracted_reply = reply.content if isinstance(reply, ChatMessage) else str(reply)
if isinstance(reply, ChatMessage):
if reply.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {reply}")
extracted_reply = reply.text
else:
extracted_reply = str(reply)
extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {}

extracted_metadata = {**extracted_metadata, **given_metadata}
Expand Down
9 changes: 6 additions & 3 deletions haystack/components/builders/chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def __init__(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
# infer variables from template
ast = self._env.parse(message.content)
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
ast = self._env.parse(message.text)
template_variables = meta.find_undeclared_variables(ast)
variables += list(template_variables)
self.variables = variables
Expand Down Expand Up @@ -192,8 +194,9 @@ def run(
for message in template:
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
self._validate_variables(set(template_variables_combined.keys()))

compiled_template = self._env.from_string(message.content)
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
compiled_template = self._env.from_string(message.text)
rendered_content = compiled_template.render(template_variables_combined)
# deep copy the message to avoid modifying the original message
rendered_message: ChatMessage = deepcopy(message)
Expand Down
8 changes: 5 additions & 3 deletions haystack/components/connectors/openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,17 @@ def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
:raises ValueError: If the content is not valid JSON or lacks required fields.
"""
function_payloads = []
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text.\nChatMessage: {message}")
try:
tool_calls = json.loads(message.content)
tool_calls = json.loads(message.text)
except json.JSONDecodeError:
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.content)
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.text)

for tool_call in tool_calls:
# this should never happen, but just in case do a sanity check
if "type" not in tool_call:
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.content)
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.text)

# In OpenAPIServiceConnector we know how to handle functions tools only
if tool_call["type"] == "function":
Expand Down
5 changes: 1 addition & 4 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ def run(
for response in completions:
self._check_finish_reason(response)

return {
"replies": [message.content for message in completions],
"meta": [message.meta for message in completions],
}
return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}

@staticmethod
def _create_message_from_chunks(
Expand Down
5 changes: 4 additions & 1 deletion haystack/components/generators/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, str]:
- `content`
- `name` (optional)
"""
openai_msg = {"role": message.role.value, "content": message.content}
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")

openai_msg = {"role": message.role.value, "content": message.text}
if message.name:
openai_msg["name"] = message.name

Expand Down
17 changes: 7 additions & 10 deletions haystack/components/validators/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,20 @@ def run(
dictionaries.
"""
last_message = messages[-1]
if not is_valid_json(last_message.content):
if last_message.text is None:
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {last_message}")
if not is_valid_json(last_message.text):
return {
"validation_error": [
ChatMessage.from_user(
f"The message '{last_message.content}' is not a valid JSON object. "
f"The message '{last_message.text}' is not a valid JSON object. "
f"Please provide only a valid JSON object in string format."
f"Don't use any markdown and don't add any comment."
)
]
}

last_message_content = json.loads(last_message.content)
last_message_content = json.loads(last_message.text)
json_schema = json_schema or self.json_schema
error_template = error_template or self.error_template or self.default_error_template

Expand Down Expand Up @@ -182,16 +184,11 @@ def run(
error_template = error_template or self.default_error_template

recovery_prompt = self._construct_error_recovery_message(
error_template,
str(e),
error_path,
error_schema_path,
validation_schema,
failing_json=last_message.content,
error_template, str(e), error_path, error_schema_path, validation_schema, failing_json=last_message.text
)
return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}

def _construct_error_recovery_message(
def _construct_error_recovery_message( # pylint: disable=too-many-positional-arguments
self,
error_template: str,
error_message: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Replace usage of `ChatMessage.content` with `ChatMessage.text` across the codebase.
This is done in preparation for the removal of `content` in Haystack 2.9.0.
8 changes: 4 additions & 4 deletions test/components/builders/test_chat_prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def test_init(self):
]
)
assert builder.required_variables == []
assert builder.template[0].content == "This is a {{ variable }}"
assert builder.template[1].content == "This is a {{ variable2 }}"
assert builder.template[0].text == "This is a {{ variable }}"
assert builder.template[1].text == "This is a {{ variable2 }}"
assert builder._variables is None
assert builder._required_variables is None

Expand Down Expand Up @@ -62,7 +62,7 @@ def test_init_with_required_variables(self):
template=[ChatMessage.from_user("This is a {{ variable }}")], required_variables=["variable"]
)
assert builder.required_variables == ["variable"]
assert builder.template[0].content == "This is a {{ variable }}"
assert builder.template[0].text == "This is a {{ variable }}"
assert builder._variables is None
assert builder._required_variables == ["variable"]

Expand All @@ -84,7 +84,7 @@ def test_init_with_custom_variables(self):
builder = ChatPromptBuilder(template=template, variables=variables)
assert builder.required_variables == []
assert builder._variables == variables
assert builder.template[0].content == "Hello, {{ var1 }}, {{ var2 }}!"
assert builder.template[0].text == "Hello, {{ var1 }}, {{ var2 }}!"
assert builder._required_variables is None

# we have inputs that contain: template, template_variables + variables
Expand Down
4 changes: 2 additions & 2 deletions test/components/connectors/test_openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_run_with_mix_params_request_body(self, openapi_mock, test_files_path):
# verify call went through on the wire
mock_service.call_greet.assert_called_once_with(parameters={"name": "John"}, data={"message": "Hello"})

response = json.loads(result["service_response"][0].content)
response = json.loads(result["service_response"][0].text)
assert response == "Hello, John"

@patch("haystack.components.connectors.openapi_service.OpenAPI")
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_run_with_complex_types(self, openapi_mock, test_files_path):
}
)

response = json.loads(result["service_response"][0].content)
response = json.loads(result["service_response"][0].text)
assert response == {"result": "accepted"}

@patch("haystack.components.connectors.openapi_service.OpenAPI")
Expand Down
2 changes: 1 addition & 1 deletion test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_live_run(self):
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text
assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

Expand Down
4 changes: 2 additions & 2 deletions test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_run(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"
assert chat_message.text == "Berlin is cool"

def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
Expand All @@ -216,4 +216,4 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel
assert isinstance(results["replies"][0], ChatMessage)
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.content == "Berlin is cool"
assert chat_message.text == "Berlin is cool"
10 changes: 5 additions & 5 deletions test/components/generators/chat/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk

@patch("haystack.components.generators.chat.openai.datetime")
def test_run_with_streaming_callback_in_run_method(self, mock_datetime, chat_messages, mock_chat_completion_chunk):
Expand All @@ -240,7 +240,7 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
assert "Hello" in response["replies"][0].text # see mock_chat_completion_chunk

assert hasattr(response["replies"][0], "meta")
assert isinstance(response["replies"][0].meta, dict)
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_live_run(self):
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text
assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"

Expand Down Expand Up @@ -322,7 +322,7 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text

assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"
Expand Down Expand Up @@ -353,7 +353,7 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "Paris" in message.text

assert "gpt-4o-mini" in message.meta["model"]
assert message.meta["finish_reason"] == "stop"
Expand Down
4 changes: 2 additions & 2 deletions test/components/validators/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def run(self):
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validated" in result["schema_validator"]
assert len(result["schema_validator"]["validated"]) == 1
assert result["schema_validator"]["validated"][0].content == genuine_fc_message
assert result["schema_validator"]["validated"][0].text == genuine_fc_message

def test_schema_validator_in_pipeline_validation_error(self, json_schema_github_compare):
@component
Expand All @@ -202,4 +202,4 @@ def run(self):
result = pipe.run(data={"schema_validator": {"json_schema": json_schema_github_compare}})
assert "validation_error" in result["schema_validator"]
assert len(result["schema_validator"]["validation_error"]) == 1
assert "Error details" in result["schema_validator"]["validation_error"][0].content
assert "Error details" in result["schema_validator"]["validation_error"][0].text
2 changes: 1 addition & 1 deletion test/core/pipeline/features/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def run(self, prompt_source: List[ChatMessage]):
class MessageMerger:
@component.output_types(merged_message=str)
def run(self, messages: List[ChatMessage], metadata: dict = None):
return {"merged_message": "\n".join(t.content for t in messages)}
return {"merged_message": "\n".join(t.text or "" for t in messages)}

@component
class FakeGenerator:
Expand Down
5 changes: 2 additions & 3 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ def test_from_dict():


def test_from_dict_with_meta():
assert ChatMessage.from_dict(
data={"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
) == ChatMessage.from_assistant("text", meta={"something": "something"})
data = {"content": "text", "role": "assistant", "name": None, "meta": {"something": "something"}}
assert ChatMessage.from_dict(data) == ChatMessage.from_assistant("text", meta={"something": "something"})


def test_content_deprecation_warning(recwarn):
Expand Down