From bf8142a7b983d939aaf7c4dea428721433a32b8e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 10 Dec 2024 13:08:38 +0100 Subject: [PATCH] fix vertex --- .../generators/google_vertex/chat/gemini.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index c94367b41..2309ca718 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any, Callable, Dict, Iterable, List, Optional, Union @@ -41,7 +42,7 @@ class VertexAIGeminiChatGenerator: messages = [ChatMessage.from_user("Tell me the name of a movie")] res = gemini_chat.run(messages) - print(res["replies"][0].content) + print(res["replies"][0].text) >>> The Shawshank Redemption ``` """ @@ -209,31 +210,31 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: def _message_to_part(self, message: ChatMessage) -> Part: if message.role == ChatRole.ASSISTANT and message.name: p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): p.function_call.args[k] = v return p - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: - return Part.from_text(message.content) - elif message.role == ChatRole.FUNCTION: - return Part.from_function_response(name=message.name, response=message.content) - elif message.role == ChatRole.USER: - return self._convert_part(message.content) + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): + return Part.from_text(message.text) + elif message.is_from(ChatRole.FUNCTION): + return Part.from_function_response(name=message.name, response=message.text) + elif message.is_from(ChatRole.USER): + return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) - for k, v in message.content.items(): + for k, v in json.loads(message.text).items(): part.function_call.args[k] = v - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: - part = Part.from_text(message.content) - elif message.role == ChatRole.FUNCTION: - part = Part.from_function_response(name=message.name, response=message.content) - elif message.role == ChatRole.USER: - part = self._convert_part(message.content) + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): + part = Part.from_text(message.text) + elif message.is_from(ChatRole.FUNCTION): + part = Part.from_function_response(name=message.name, response=message.text) + elif message.is_from(ChatRole.USER): + part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage]) @@ -283,7 +284,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: elif part.function_call: metadata["function_call"] = part.function_call new_message = ChatMessage.from_assistant( - content=dict(part.function_call.args.items()), meta=metadata + content=json.dumps(dict(part.function_call.args)), meta=metadata ) new_message.name = part.function_call.name replies.append(new_message) @@ -311,7 +312,7 @@ def _get_stream_response( replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - content = dict(part.function_call.args.items()) + content = json.dumps(dict(part.function_call.args)) new_message = ChatMessage.from_assistant(content, meta=metadata) new_message.name = part.function_call.name replies.append(new_message)