From 7483bc65577455b6929713ba39ee680604d9e304 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 27 Aug 2024 15:57:42 +0200 Subject: [PATCH 1/4] Replace system roles in responses --- .../components/generators/google_vertex/chat/gemini.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 21fa1f52f..d32703a35 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -236,7 +236,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: for candidate in response_body.candidates: for part in candidate.content.parts: if part._raw_part.text != "": - replies.append(ChatMessage.from_system(part.text)) + replies.append(ChatMessage.from_assistant(part.text)) elif part.function_call is not None: replies.append( ChatMessage( @@ -264,4 +264,4 @@ def _get_stream_response( responses.append(streaming_chunk.content) combined_response = "".join(responses).lstrip() - return [ChatMessage.from_system(content=combined_response)] + return [ChatMessage.from_assistant(content=combined_response)] From a6a648b7d5fcec60786ea1a1571ed86d785cab53 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 27 Aug 2024 16:05:08 +0200 Subject: [PATCH 2/4] Add googleai fix --- .../components/generators/google_ai/chat/gemini.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 20e143ba7..d81dea552 100644 --- 
a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -326,7 +326,7 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess for candidate in response_body.candidates: for part in candidate.content.parts: if part.text != "": - replies.append(ChatMessage.from_system(part.text)) + replies.append(ChatMessage.from_assistant(part.text)) elif part.function_call is not None: replies.append( ChatMessage( @@ -354,4 +354,4 @@ def _get_stream_response( responses.append(content) combined_response = "".join(responses).lstrip() - return [ChatMessage.from_system(content=combined_response)] + return [ChatMessage.from_assistant(content=combined_response)] From 4c928ac44e41674e072ef3d6afe0ad5cce8134eb Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 27 Aug 2024 16:23:35 +0200 Subject: [PATCH 3/4] Update gemini.py --- .../components/generators/google_ai/chat/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index d81dea552..62c86285e 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -331,7 +331,7 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess replies.append( ChatMessage( content=dict(part.function_call.args.items()), - role=ChatRole.SYSTEM, + role=ChatRole.ASSISTANT, name=part.function_call.name, ) ) From 56106c3cedc2c3ce4e158707fa4b7f7fe7a94e0a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 27 Aug 2024 18:56:44 +0200 Subject: [PATCH 4/4] Fix role handling --- 
.../amazon_bedrock/tests/test_chat_generator.py | 2 +- .../components/generators/google_ai/chat/gemini.py | 8 ++++---- .../tests/generators/chat/test_chat_gemini.py | 3 ++- .../components/generators/google_vertex/chat/gemini.py | 10 +++++----- integrations/ollama/examples/chat_generator_example.py | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ed0c27401..a455d2c93 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -200,7 +200,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): """ Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False """ - messages = [ChatMessage.from_system("What is the biggest city in United States?")] + messages = [ChatMessage.from_user("What is the biggest city in United States?")] # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) max_length_generated_text = 3 diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 62c86285e..5d00be746 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -241,14 +241,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.SYSTEM and message.name: + if message.role == ChatRole.ASSISTANT and message.name: p = Part() p.function_call.name = message.name p.function_call.args = {} for k, v in 
message.content.items(): p.function_call.args[k] = v return p - elif message.role == ChatRole.SYSTEM: + elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: p = Part() p.text = message.content return p @@ -261,13 +261,13 @@ def _message_to_part(self, message: ChatMessage) -> Part: return self._convert_part(message.content) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.SYSTEM and message.name: + if message.role == ChatRole.ASSISTANT and message.name: part = Part() part.function_call.name = message.name part.function_call.args = {} for k, v in message.content.items(): part.function_call.args[k] = v - elif message.role == ChatRole.SYSTEM: + elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: part = Part() part.text = message.content elif message.role == ChatRole.FUNCTION: diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 0302a3da7..04d4387ef 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -256,8 +256,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 def test_past_conversation(): gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro") messages = [ + ChatMessage.from_system(content="You are a knowledgeable mathematician."), ChatMessage.from_user(content="What is 2+2?"), - ChatMessage.from_system(content="It's an arithmetic operation."), + ChatMessage.from_assistant(content="It's an arithmetic operation."), ChatMessage.from_user(content="Yeah, but what's the result?"), ] res = gemini_chat.run(messages=messages) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index d32703a35..7d1a15f0d 100644 --- 
a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -164,12 +164,12 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.SYSTEM and message.name: + if message.role == ChatRole.ASSISTANT and message.name: p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) for k, v in message.content.items(): p.function_call.args[k] = v return p - elif message.role == ChatRole.SYSTEM: + elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: return Part.from_text(message.content) elif message.role == ChatRole.FUNCTION: return Part.from_function_response(name=message.name, response=message.content) @@ -177,11 +177,11 @@ def _message_to_part(self, message: ChatMessage) -> Part: return self._convert_part(message.content) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.SYSTEM and message.name: + if message.role == ChatRole.ASSISTANT and message.name: part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) for k, v in message.content.items(): part.function_call.args[k] = v - elif message.role == ChatRole.SYSTEM: + elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: part = Part.from_text(message.content) elif message.role == ChatRole.FUNCTION: part = Part.from_function_response(name=message.name, response=message.content) @@ -241,7 +241,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: replies.append( ChatMessage( content=dict(part.function_call.args.items()), - role=ChatRole.SYSTEM, + role=ChatRole.ASSISTANT, name=part.function_call.name, ) ) diff --git a/integrations/ollama/examples/chat_generator_example.py 
b/integrations/ollama/examples/chat_generator_example.py index 834df78fb..2326ba708 100644 --- a/integrations/ollama/examples/chat_generator_example.py +++ b/integrations/ollama/examples/chat_generator_example.py @@ -11,7 +11,7 @@ messages = [ ChatMessage.from_user("What's Natural Language Processing?"), - ChatMessage.from_system( + ChatMessage.from_assistant( "Natural Language Processing (NLP) is a field of computer science and artificial " "intelligence concerned with the interaction between computers and human language" ),