Skip to content

Commit

Permalink
Add metadata to chat responses
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 18, 2024
1 parent 56ab69e commit 82f5035
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,18 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
replies = []
for candidate in response_body.candidates:
for part in candidate.content.parts:
metadata=candidate.to_dict()
metadata.pop("content")
if part._raw_part.text != "":
replies.append(ChatMessage.from_assistant(part.text))
replies.append(ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT,name = None, meta=metadata))
elif part.function_call is not None:

replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata
)
)
return replies
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from vertexai.generative_models import Tool, FunctionDeclaration

get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
tool = Tool([get_current_weather_func])

def get_current_weather(location: str, unit: str = "celsius"):
return {"weather": "sunny", "temperature": 21.8, "unit": unit}

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator


gemini_chat = VertexAIGeminiChatGenerator(project_id="my-project-1487737228087", tools=[tool])
from haystack.dataclasses import ChatMessage


messages = [ChatMessage.from_user("What is the temperature in celsius in Berlin?")]
res = gemini_chat.run(messages=messages)
print ("RESPONSE")
print (res)
print(res["replies"][0].content)

weather = get_current_weather(**res["replies"][0].content)

messages += res["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]

res = gemini_chat.run(messages=messages)
print (res)
print(res["replies"][0].content)
57 changes: 53 additions & 4 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,10 @@ def test_run(mock_generative_model):

mock_model.send_message.assert_called_once()
assert "replies" in response
assert len(response["replies"]) == 1
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

chat_message = response["replies"][0]
assert chat_message.content
assert chat_message.is_from(ChatRole.ASSISTANT)



@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
Expand Down Expand Up @@ -305,3 +304,53 @@ def test_serialization_deserialization_pipeline():

new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline

@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_function_call_and_execute(mock_generative_model):
mock_model = Mock()
mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])

mock_model.send_message.return_value = mock_response
mock_model.start_chat.return_value = mock_model
mock_generative_model.return_value = mock_model

get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)

def get_current_weather(location: str, unit: str = "celsius"):
return {"weather": "sunny", "temperature": 21.8, "unit": unit}


tool = Tool(function_declarations=[get_current_weather_func])
messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, tools=[tool])

response = gemini.run(messages=messages)
assert "replies" in response
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

assert len(response["replies"]) > 0
print (response)

first_reply = response["replies"][0]
assert "tool_calls" in first_reply.meta
tool_calls = first_reply.meta["tool_calls"]


0 comments on commit 82f5035

Please sign in to comment.