From 01fcb808754f3e53d22e17ab3a9b8b64e68e076a Mon Sep 17 00:00:00 2001
From: anakin87
Date: Wed, 18 Dec 2024 15:56:31 +0100
Subject: [PATCH] improvements

---
 .../generators/google_ai/chat/gemini.py       | 48 ++++++++-----------
 .../tests/generators/chat/test_chat_gemini.py |  3 +-
 2 files changed, 22 insertions(+), 29 deletions(-)

diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index 255625aae..ab99de20a 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -231,13 +231,9 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
             raise ValueError(msg)
 
     def _message_to_part(self, message: ChatMessage) -> Part:
-        name = getattr(message, "name", None)
-        if name is None:
-            name = getattr(message, "_name", None)
-
-        if message.is_from(ChatRole.ASSISTANT) and name:
+        if message.is_from(ChatRole.ASSISTANT) and message.name:
             p = Part()
-            p.function_call.name = name
+            p.function_call.name = message.name
             p.function_call.args = {}
             for k, v in json.loads(message.text).items():
                 p.function_call.args[k] = v
@@ -248,27 +244,21 @@ def _message_to_part(self, message: ChatMessage) -> Part:
             return p
         elif message.is_from(ChatRole.FUNCTION):
             p = Part()
-            p.function_response.name = name
+            p.function_response.name = message.name
             p.function_response.response = message.text
             return p
         elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
-            print("********* HERE *********")
-            part = Part()
-            part.function_response.name = message.tool_call_result.origin.tool_name
-            part.function_response.response = message.tool_call_result.result
-            print(part)
+            p = Part()
+            p.function_response.name = message.tool_call_result.origin.tool_name
+            p.function_response.response = message.tool_call_result.result
+            return p
         elif message.is_from(ChatRole.USER):
             return self._convert_part(message.text)
 
     def _message_to_content(self, message: ChatMessage) -> Content:
-        # support both new and legacy ChatMessage
-        name = getattr(message, "name", None)
-        if name is None:
-            name = getattr(message, "_name", None)
-
-        if message.is_from(ChatRole.ASSISTANT) and name:
+        if message.is_from(ChatRole.ASSISTANT) and message.name:
             part = Part()
-            part.function_call.name = name
+            part.function_call.name = message.name
             part.function_call.args = {}
             for k, v in json.loads(message.text).items():
                 part.function_call.args[k] = v
@@ -277,20 +267,26 @@ def _message_to_content(self, message: ChatMessage) -> Content:
             part.text = message.text
         elif message.is_from(ChatRole.FUNCTION):
             part = Part()
-            part.function_response.name = name
+            part.function_response.name = message.name
             part.function_response.response = message.text
+        elif message.is_from(ChatRole.USER):
+            part = self._convert_part(message.text)
         elif "TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL):
-            print("********* HERE *********")
             part = Part()
             part.function_response.name = message.tool_call_result.origin.tool_name
-            part.function_response.response = message.tool_call_result.result
-            print(part)
+            part.function_response.response = message.tool_call_result.result
         elif message.is_from(ChatRole.USER):
             part = self._convert_part(message.text)
         else:
             msg = f"Unsupported message role {message.role}"
             raise ValueError(msg)
-        role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"
+        role = (
+            "user"
+            if message.is_from(ChatRole.USER)
+            or message.is_from(ChatRole.FUNCTION)
+            or ("TOOL" in ChatRole._member_names_ and message.is_from(ChatRole.TOOL))
+            else "model"
+        )
         return Content(parts=[part], role=role)
 
     @component.output_types(replies=List[ChatMessage])
@@ -312,11 +308,9 @@ def run(
         """
         streaming_callback = streaming_callback or self._streaming_callback
         history = [self._message_to_content(m) for m in messages[:-1]]
-        print(history)
 
         session = self._model.start_chat(history=history)
         new_message = self._message_to_part(messages[-1])
-        print(new_message)
         res = session.send_message(
             content=new_message,
             generation_config=self._generation_config,
@@ -395,7 +389,7 @@ def _get_stream_response(
                         metadata["function_call"] = part["function_call"]
                         content = json.dumps(dict(part["function_call"]["args"]))
                         new_message = ChatMessage.from_assistant(content, meta=metadata)
-                        try: 
+                        try:
                             new_message.name = part["function_call"]["name"]
                         except AttributeError:
                             new_message._name = part["function_call"]["name"]
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index 1ccf9e1d3..0683bf21a 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -272,9 +272,8 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     assert "function_call" in chat_message.meta
     assert json.loads(chat_message.text) == {"location": "Berlin", "unit": "celsius"}
 
-    weather = str(get_current_weather(**json.loads(response["replies"][0].text)))
+    weather = get_current_weather(**json.loads(chat_message.text))
     messages += response["replies"] + [ChatMessage.from_function(weather, name="get_current_weather")]
-    print(messages)
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0