diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index e859a29fd..56c84968b 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -311,17 +311,25 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
         :param response_body: The response from Google AI request.
         :returns: The extracted responses.
         """
-        replies = []
-        for candidate in response_body.candidates:
+        replies: List[ChatMessage] = []
+        metadata = response_body.to_dict()
+        for idx, candidate in enumerate(response_body.candidates):
+            candidate_metadata = metadata["candidates"][idx]
+            candidate_metadata.pop("content", None)  # we remove content from the metadata
+
             for part in candidate.content.parts:
                 if part.text != "":
-                    replies.append(ChatMessage.from_assistant(part.text))
-                elif part.function_call is not None:
+                    replies.append(
+                        ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata)
+                    )
+                elif part.function_call:
+                    candidate_metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
+                            meta=candidate_metadata,
                         )
                     )
         return replies
@@ -336,11 +344,26 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
         """
-        responses = []
+        replies: List[ChatMessage] = []
         for chunk in stream:
-            content = chunk.text if len(chunk.parts) > 0 and "text" in chunk.parts[0] else ""
-            streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict()))
-            responses.append(content)
+            content: Union[str, Dict[str, Any]] = ""
+            metadata = chunk.to_dict()  # we store whole chunk as metadata in streaming calls
+            for candidate in chunk.candidates:
+                for part in candidate.content.parts:
+                    if part.text != "":
+                        content = part.text
+                        replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None))
+                    elif part.function_call is not None:
+                        metadata["function_call"] = part.function_call
+                        content = dict(part.function_call.args.items())
+                        replies.append(
+                            ChatMessage(
+                                content=content,
+                                role=ChatRole.ASSISTANT,
+                                name=part.function_call.name,
+                                meta=metadata,
+                            )
+                        )
 
-        combined_response = "".join(responses).lstrip()
-        return [ChatMessage.from_assistant(content=combined_response)]
+            streaming_callback(StreamingChunk(content=content, meta=metadata))
+        return replies
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index 35ad8db14..c4372db0d 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -5,7 +5,7 @@
 from google.generativeai import GenerationConfig, GenerativeModel
 from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
 from haystack.dataclasses import StreamingChunk
-from haystack.dataclasses.chat_message import ChatMessage
+from haystack.dataclasses.chat_message import ChatMessage, ChatRole
 
 from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
 
@@ -207,7 +207,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     get_current_weather_func = FunctionDeclaration.from_function(
         get_current_weather,
         descriptions={
-            "location": "The city and state, e.g. San Francisco, CA",
+            "location": "The city, e.g. San Francisco",
             "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
         },
     )
@@ -215,14 +215,27 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     tool = Tool(function_declarations=[get_current_weather_func])
     gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool])
     messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
-    res = gemini_chat.run(messages=messages)
-    assert len(res["replies"]) > 0
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
 
-    weather = get_current_weather(**res["replies"][0].content)
-    messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
+    # check the first response is a function call
+    chat_message = response["replies"][0]
+    assert "function_call" in chat_message.meta
+    assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
 
-    res = gemini_chat.run(messages=messages)
-    assert len(res["replies"]) > 0
+    weather = get_current_weather(**chat_message.content)
+    messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+
+    # check the second response is not a function call
+    chat_message = response["replies"][0]
+    assert "function_call" not in chat_message.meta
+    assert isinstance(chat_message.content, str)
 
 
 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -239,7 +252,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     get_current_weather_func = FunctionDeclaration.from_function(
         get_current_weather,
         descriptions={
-            "location": "The city and state, e.g. San Francisco, CA",
+            "location": "The city, e.g. San Francisco",
             "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
         },
     )
@@ -247,10 +260,29 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     tool = Tool(function_declarations=[get_current_weather_func])
     gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[tool], streaming_callback=streaming_callback)
     messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
-    res = gemini_chat.run(messages=messages)
-    assert len(res["replies"]) > 0
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
     assert streaming_callback_called
 
+    # check the first response is a function call
+    chat_message = response["replies"][0]
+    assert "function_call" in chat_message.meta
+    assert chat_message.content == {"location": "Berlin", "unit": "celsius"}
+
+    weather = get_current_weather(**response["replies"][0].content)
+    messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+
+    # check the second response is not a function call
+    chat_message = response["replies"][0]
+    assert "function_call" not in chat_message.meta
+    assert isinstance(chat_message.content, str)
+
 
 @pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
 def test_past_conversation():
@@ -261,5 +293,7 @@ def test_past_conversation():
         ChatMessage.from_assistant(content="It's an arithmetic operation."),
         ChatMessage.from_user(content="Yeah, but what's the result?"),
     ]
-    res = gemini_chat.run(messages=messages)
-    assert len(res["replies"]) > 0
+    response = gemini_chat.run(messages=messages)
+    assert "replies" in response
+    assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index e693c10f4..ac4c93228 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -229,17 +229,24 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
         :param response_body: The response from Vertex AI request.
         :returns: The extracted responses.
""" - replies = [] + replies: List[ChatMessage] = [] for candidate in response_body.candidates: + metadata = candidate.to_dict() for part in candidate.content.parts: + # Remove content from metadata + metadata.pop("content", None) if part._raw_part.text != "": - replies.append(ChatMessage.from_assistant(part.text)) - elif part.function_call is not None: + replies.append( + ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) + ) + elif part.function_call: + metadata["function_call"] = part.function_call replies.append( ChatMessage( content=dict(part.function_call.args.items()), role=ChatRole.ASSISTANT, name=part.function_call.name, + meta=metadata, ) ) return replies @@ -254,11 +261,27 @@ def _get_stream_response( :param streaming_callback: The handler for the streaming response. :returns: The extracted response with the content of all streaming chunks. """ - responses = [] + replies: List[ChatMessage] = [] + for chunk in stream: - streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) - streaming_callback(streaming_chunk) - responses.append(streaming_chunk.content) + content: Union[str, Dict[str, Any]] = "" + metadata = chunk.to_dict() # we store whole chunk as metadata for streaming + for candidate in chunk.candidates: + for part in candidate.content.parts: + if part._raw_part.text: + content = chunk.text + replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + elif part.function_call: + metadata["function_call"] = part.function_call + content = dict(part.function_call.args.items()) + replies.append( + ChatMessage( + content=content, + role=ChatRole.ASSISTANT, + name=part.function_call.name, + meta=metadata, + ) + ) + streaming_callback(StreamingChunk(content=content, meta=metadata)) - combined_response = "".join(responses).lstrip() - return [ChatMessage.from_assistant(content=combined_response)] + return replies diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index a1564b9f2..ab21008fb 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -3,7 +3,7 @@ import pytest from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from vertexai.generative_models import ( Content, FunctionDeclaration, @@ -249,9 +249,12 @@ def test_run(mock_generative_model): ChatMessage.from_user("What's the capital of France?"), ] gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) - gemini.run(messages=messages) + response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() + assert "replies" in response + assert len(response["replies"]) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") @@ -260,25 +263,24 @@ def test_run_with_streaming_callback(mock_generative_model): mock_responses = iter( [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] ) - mock_model.send_message.return_value = mock_responses mock_model.start_chat.return_value = mock_model mock_generative_model.return_value = mock_model streaming_callback_called = [] - def streaming_callback(chunk: StreamingChunk) -> None: 
-        streaming_callback_called.append(chunk.content)
+    def streaming_callback(_chunk: StreamingChunk) -> None:
+        nonlocal streaming_callback_called
+        streaming_callback_called = True
 
     gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
     messages = [
         ChatMessage.from_system("You are a helpful assistant"),
         ChatMessage.from_user("What's the capital of France?"),
     ]
-    gemini.run(messages=messages)
-
+    response = gemini.run(messages=messages)
     mock_model.send_message.assert_called_once()
-    assert streaming_callback_called == ["First part", "Second part"]
+    assert "replies" in response
 
 
 def test_serialization_deserialization_pipeline():
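
Usage sketch (illustrative, not part of the patch; it mirrors the updated google_ai tests above and assumes a valid GOOGLE_API_KEY in the environment, the "gemini-pro" model, and a toy get_current_weather body): with this change a function-call reply exposes the raw call in reply.meta["function_call"] and the parsed arguments in reply.content, so the tool result can be fed back with ChatMessage.from_function.

from google.generativeai.types import FunctionDeclaration, Tool
from haystack.dataclasses.chat_message import ChatMessage

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator


def get_current_weather(location: str, unit: str = "celsius"):
    # Toy stand-in for a real weather lookup.
    return f"It is 21 {unit} in {location}."


weather_func = FunctionDeclaration.from_function(
    get_current_weather,
    descriptions={
        "location": "The city, e.g. San Francisco",
        "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
    },
)
chat = GoogleAIGeminiChatGenerator(model="gemini-pro", tools=[Tool(function_declarations=[weather_func])])

messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
reply = chat.run(messages=messages)["replies"][0]

if "function_call" in reply.meta:
    # reply.content holds the parsed arguments, e.g. {"location": "Berlin", "unit": "celsius"}.
    weather = get_current_weather(**reply.content)
    messages += [reply, ChatMessage.from_function(content=weather, name="get_current_weather")]
    print(chat.run(messages=messages)["replies"][0].content)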