def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
    """Convert a Vertex AI ``GenerationResponse`` into haystack ``ChatMessage`` replies.

    Each candidate's metadata — ``candidate.to_dict()`` minus the ``content``
    key (e.g. finish reason, safety ratings) — is attached to every reply built
    from that candidate via ``ChatMessage.meta``.

    :param response_body: Raw response returned by the Gemini model.
    :returns: One assistant ``ChatMessage`` per non-empty text part, plus one
        per function call (arguments as content, function name as ``name``).
    """
    replies = []
    for candidate in response_body.candidates:
        # Loop-invariant per candidate: compute once, not once per part.
        candidate_metadata = candidate.to_dict()
        # The content is the reply itself; drop it from the metadata.
        # Use a default so a content-less candidate dict doesn't raise.
        candidate_metadata.pop("content", None)
        for part in candidate.content.parts:
            if part._raw_part.text != "":
                replies.append(
                    ChatMessage(
                        content=part._raw_part.text,
                        role=ChatRole.ASSISTANT,
                        name=None,
                        # Shallow copy so one reply's meta can be mutated
                        # without affecting sibling replies.
                        meta=dict(candidate_metadata),
                    )
                )
            elif part.function_call is not None:
                replies.append(
                    ChatMessage(
                        content=dict(part.function_call.args.items()),
                        role=ChatRole.ASSISTANT,
                        name=part.function_call.name,
                        meta=dict(candidate_metadata),
                    )
                )
    return replies
"""Manual demo: Vertex AI Gemini function calling through haystack.

Performs a live weather function-call round-trip against a Vertex AI project;
requires Google Cloud credentials. This is a hand-run example, not a unit test.
"""
from haystack.dataclasses import ChatMessage
from vertexai.generative_models import FunctionDeclaration, Tool

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator

# Function schema advertised to the model so it can emit a matching call.
get_current_weather_func = FunctionDeclaration(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        },
        "required": ["location"],
    },
)
tool = Tool([get_current_weather_func])


def get_current_weather(location: str, unit: str = "celsius"):
    """Stub weather lookup; echoes back the requested unit."""
    return {"weather": "sunny", "temperature": 21.8, "unit": unit}


def main():
    """Ask for Berlin's temperature, execute the requested call, feed the result back."""
    gemini_chat = VertexAIGeminiChatGenerator(project_id="my-project-1487737228087", tools=[tool])

    messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
    res = gemini_chat.run(messages=messages)
    print("RESPONSE")
    print(res)
    print(res["replies"][0].content)

    # The first reply carries the function-call arguments; execute the call
    # locally and hand the result back to the model as a function message.
    weather = get_current_weather(**res["replies"][0].content)
    messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]

    res = gemini_chat.run(messages=messages)
    print(res)
    print(res["replies"][0].content)


if __name__ == "__main__":
    main()
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_function_call_and_execute(mock_generative_model):
    """Replies expose the candidate metadata (minus ``content``) via ``ChatMessage.meta``."""
    mock_model = Mock()
    mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
    # to_dict() must return a real dict: _get_response pops "content" and
    # stores the remainder as the reply's meta. A bare Mock here would make
    # membership/equality checks on meta raise TypeError.
    mock_candidate.to_dict.return_value = {
        "content": {"parts": [{"text": "This is a generated response."}], "role": "model"},
        "finish_reason": "STOP",
    }
    mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])

    mock_model.send_message.return_value = mock_response
    mock_model.start_chat.return_value = mock_model
    mock_generative_model.return_value = mock_model

    get_current_weather_func = FunctionDeclaration(
        name="get_current_weather",
        description="Get the current weather in a given location",
        parameters={
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    )
    tool = Tool(function_declarations=[get_current_weather_func])

    messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, tools=[tool])
    response = gemini.run(messages=messages)

    assert "replies" in response
    assert len(response["replies"]) > 0
    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

    # Meta is the candidate dict with the "content" key removed.
    first_reply = response["replies"][0]
    assert first_reply.meta == {"finish_reason": "STOP"}