From e0a7b2680865249fa4e44d40f87dee1e5cf91bfa Mon Sep 17 00:00:00 2001
From: anakin87
Date: Wed, 27 Nov 2024 18:54:58 +0100
Subject: [PATCH] use class methods to create ChatMessage

---
 .../tests/test_chat_generator.py              |  4 +--
 .../components/generators/cohere/generator.py |  4 +--
 .../tests/test_cohere_chat_generator.py       | 12 ++++----
 .../generators/google_ai/chat/gemini.py       | 28 ++++++-------------
 .../generators/google_vertex/chat/gemini.py   | 28 ++++++-------------
 .../ollama/tests/test_chat_generator.py       | 27 ++++++------------
 6 files changed, 35 insertions(+), 68 deletions(-)

diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 185a34c8a..22594af2c 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -226,10 +226,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):

     generator.model_adapter.get_responses = MagicMock(
         return_value=[
-            ChatMessage(
+            ChatMessage.from_assistant(
                 content="Some text",
-                role=ChatRole.ASSISTANT,
-                name=None,
                 meta={
                     "model": "claude-3-sonnet-20240229",
                     "index": 0,
diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
index 0eb65b368..e4eaf8670 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
@@ -5,7 +5,7 @@
 from typing import Any, Callable, Dict, List, Optional

 from haystack import component
-from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.dataclasses import ChatMessage
 from haystack.utils import Secret

 from .chat.chat_generator import CohereChatGenerator
@@ -64,7 +64,7 @@ def run(self, prompt: str):
             - `replies`: A list of replies generated by the model.
             - `meta`: Information about the request.
""" - chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={}) + chat_message = ChatMessage.from_user(prompt) # Note we have to call super() like this because of the way components are dynamically built with the decorator results = super(CohereGenerator, self).run([chat_message]) # noqa return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 175a6d14b..b7cc0534a 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] + return [ChatMessage.from_assistant(content="What's the capital of France")] class TestCohereChatGenerator: @@ -164,7 +164,7 @@ def test_message_to_dict(self, chat_messages): ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -201,9 +201,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(streaming_callback=callback) - results = component.run( - [ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)] - ) + results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -224,7 +222,7 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 @@ -249,7 +247,7 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] + chat_messages = [ChatMessage.from_user(content="What's the capital of France? 
answer in a word")] component = CohereChatGenerator(streaming_callback=callback) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index dbcab619d..ef7d583be 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -334,19 +334,14 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for part in candidate.content.parts: if part.text != "": - replies.append( - ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata) - ) + replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=candidate_metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), meta=candidate_metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -368,18 +363,13 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) + replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = part["function_call"]["args"] - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part["function_call"]["name"], - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content=content, meta=metadata) + new_message.name = part["function_call"]["name"] + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index c52f76dc6..c94367b41 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -279,19 +279,14 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: # Remove content from metadata metadata.pop("content", None) if part._raw_part.text != "": - replies.append( - ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) - ) + replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), 
meta=metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -313,18 +308,13 @@ def _get_stream_response( for part in candidate.content.parts: if part._raw_part.text: content = chunk.text - replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call content = dict(part.function_call.args.items()) - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content, meta=metadata) + new_message.name = part.function_call.name + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index b2b3fd927..0308f42ec 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -3,7 +3,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage from ollama._types import ChatResponse, ResponseError from haystack_integrations.components.generators.ollama import OllamaChatGenerator @@ -128,16 +128,12 @@ def test_run_with_chat_history(self): chat_generator = OllamaChatGenerator() chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history - ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) assert isinstance(response, dict) assert isinstance(response["replies"], list) @@ -159,17 +155,12 @@ def test_run_with_streaming(self): chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, - ] - - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) streaming_callback.assert_called()
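
A minimal sketch of the construction pattern this patch adopts, assuming a Haystack 2.x ChatMessage that exposes the from_user and from_assistant class methods; the model name in meta below is a made-up placeholder, not taken from the patch:

    from haystack.dataclasses import ChatMessage

    # The class-method constructors set the role internally, so callers no longer
    # pass role, name, and meta explicitly to ChatMessage(...).
    user_msg = ChatMessage.from_user("What's the capital of France?")
    assistant_msg = ChatMessage.from_assistant("Paris", meta={"model": "example-model"})

    assert user_msg.role.value == "user"
    assert assistant_msg.meta == {"model": "example-model"}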