Fix meta in chat
Amnah199 committed Sep 19, 2024
1 parent 82f5035 commit 1bdfdb3
Showing 5 changed files with 43 additions and 130 deletions.
@@ -312,16 +312,20 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]:
         :returns: The extracted responses.
         """
         replies = []
+        metadata = response_body.to_dict()
+        [candidate.pop("content", None) for candidate in metadata["candidates"]]
         for candidate in response_body.candidates:
             for part in candidate.content.parts:
                 if part.text != "":
-                    replies.append(ChatMessage.from_assistant(part.text))
+                    replies.append(ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
                 elif part.function_call is not None:
+                    metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
+                            meta=metadata,
                         )
                     )
         return replies
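
For orientation, here is a hedged sketch of consuming the richer replies this hunk produces. The generator class and its zero-argument construction are assumptions (only the extraction helper appears in this diff); the meta shape, the response dict with candidate content stripped plus a `function_call` entry for tool calls, follows directly from the code above.

```python
# Minimal sketch, not part of this commit. Assumes GOOGLE_API_KEY is set and
# that GoogleAIGeminiChatGenerator is the generator this helper belongs to.
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

gemini_chat = GoogleAIGeminiChatGenerator()
response = gemini_chat.run(messages=[ChatMessage.from_user("What is the weather in Berlin?")])

for reply in response["replies"]:
    if "function_call" in reply.meta:
        # Tool-call reply: the parsed args live in .content, the function name in .name.
        print("call:", reply.name, reply.content)
    else:
        # Text reply: meta still carries the content-stripped candidates.
        print("text:", reply.content, reply.meta.get("candidates"))
```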
@@ -336,27 +340,32 @@ def _get_stream_response(
         :param streaming_callback: The handler for the streaming response.
         :returns: The extracted response with the content of all streaming chunks.
         """
-        responses: Union[List[str], List[ChatMessage]] = []
+        replies: Union[List[str], List[ChatMessage]] = []
+        metadata = stream.to_dict()
+
+        for candidate in metadata.get("candidates", []):
+            candidate.pop("content", None)
+
         for chunk in stream:
             for candidate in chunk.candidates:
                 for part in candidate.content.parts:
                     if part.text != "":
-                        content = part.text
-                        responses.append(content)
+                        replies.append(part.text)
                     elif part.function_call is not None:
-                        content = dict(part.function_call.args.items())
-                        responses.append(
+                        metadata["function_call"] = part.function_call
+                        replies.append(
                             ChatMessage(
                                 content=dict(part.function_call.args.items()),
                                 role=ChatRole.ASSISTANT,
                                 name=part.function_call.name,
+                                meta=metadata,
                             )
                         )

-            streaming_callback(StreamingChunk(content=content, meta=chunk.to_dict()))
+            streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))

-        if isinstance(responses[0], ChatMessage):
-            return responses
+        if isinstance(replies[0], ChatMessage):
+            return replies

-        combined_response = "".join(responses).lstrip()
+        combined_response = "".join(replies).lstrip()
         return [ChatMessage.from_assistant(content=combined_response)]
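
Similarly, a hedged usage sketch of the streaming path: the callback receives one `StreamingChunk` per chunk, and text-only streams are joined into a single assistant reply at the end. The generator construction is an assumption; the callback contract comes from the hunk above.

```python
# Minimal sketch, not part of this commit. chunk.meta is the raw chunk dict
# (chunk.to_dict() in the code above).
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator

def print_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)

gemini_chat = GoogleAIGeminiChatGenerator(streaming_callback=print_chunk)
result = gemini_chat.run(messages=[ChatMessage.from_user("Write a haiku about diffs.")])

# Text chunks are concatenated (and lstrip-ped) into one assistant message.
print()
print(result["replies"][0].content)
```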
24 changes: 12 additions & 12 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -207,7 +207,7 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     get_current_weather_func = FunctionDeclaration.from_function(
         get_current_weather,
         descriptions={
-            "location": "The city and state, e.g. San Francisco, CA",
+            "location": "The city and state, e.g. San Francisco",
             "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
         },
     )
@@ -218,18 +218,22 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+
     chat_message = response["replies"][0]
-    assert chat_message.content
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert "function_call" in chat_message.meta
     assert chat_message.content == {"location": "Berlin", "unit": "celsius"}

     weather = get_current_weather(**chat_message.content)
     messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+
     chat_message = response["replies"][0]
+    assert "function_call" not in chat_message.meta
     assert chat_message.content
     assert chat_message.is_from(ChatRole.ASSISTANT)

@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -257,17 +261,15 @@ def get_current_weather(location: str, unit: str = "celsius"):  # noqa: ARG001
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0
-    chat_message = response["replies"][0]
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
     assert streaming_callback_called

-    weather = get_current_weather(**chat_message.content)
+    weather = get_current_weather(**response["replies"][0].content)
     messages += response["replies"] + [ChatMessage.from_function(content=weather, name="get_current_weather")]
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0
-    chat_message = response["replies"][-1]
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])

@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
@@ -282,6 +284,4 @@ def test_past_conversation():
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
     assert len(response["replies"]) > 0
-    chat_message = response["replies"][0]
-    assert chat_message.content
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
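
Condensed from the tests above, the function-calling round trip reads as follows. This is a hedged sketch that assumes `gemini_chat`, `messages`, and the weather tool are set up exactly as the test constructs them.

```python
# Hedged sketch of the round trip the tests exercise (setup as in the test above).
response = gemini_chat.run(messages=messages)
call = response["replies"][0]                  # the model requests a tool call
assert "function_call" in call.meta            # meta flag added by this commit
weather = get_current_weather(**call.content)  # .content holds the parsed args

# Feed the tool result back so the model can answer in natural language.
messages += response["replies"] + [
    ChatMessage.from_function(content=weather, name="get_current_weather")
]
final = gemini_chat.run(messages=messages)["replies"][0]
assert "function_call" not in final.meta
```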
@@ -231,19 +231,21 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
         """
         replies = []
         for candidate in response_body.candidates:
+            metadata = candidate.to_dict()
+            metadata.pop("content")
             for part in candidate.content.parts:
-                metadata=candidate.to_dict()
-                metadata.pop("content")
                 if part._raw_part.text != "":
-                    replies.append(ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT,name = None, meta=metadata))
+                    replies.append(
+                        ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
+                    )
                 elif part.function_call is not None:
-
+                    metadata["function_call"] = part.function_call
                     replies.append(
                         ChatMessage(
                             content=dict(part.function_call.args.items()),
                             role=ChatRole.ASSISTANT,
                             name=part.function_call.name,
-                            meta=metadata
+                            meta=metadata,
                         )
                     )
         return replies
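
A short hedged sketch of what the per-candidate meta looks like on the Vertex side: each reply carries its candidate's dict minus the raw content. The `finish_reason` and `safety_ratings` keys are assumptions about typical `candidate.to_dict()` output; only `function_call` is confirmed by the hunk above.

```python
# Minimal sketch, not part of this commit. The project id is a placeholder.
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator

gemini = VertexAIGeminiChatGenerator(project_id="my-gcp-project")
reply = gemini.run(messages=[ChatMessage.from_user("Hello!")])["replies"][0]

# meta is candidate.to_dict() without "content"; the key names below are assumed.
print(reply.meta.get("finish_reason"), reply.meta.get("safety_ratings"))
if "function_call" in reply.meta:
    print(reply.name, reply.content)
```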

This file was deleted.

63 changes: 5 additions & 58 deletions integrations/google_vertex/tests/chat/test_gemini.py
@@ -256,14 +256,12 @@ def test_run(mock_generative_model):
     assert len(response["replies"]) > 0
     assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


-
-
 @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
 def test_run_with_streaming_callback(mock_generative_model):
     mock_model = Mock()
     mock_responses = iter(
-        [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
+        [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text=" Second part")]
     )

     mock_model.send_message.return_value = mock_responses
@@ -283,13 +281,12 @@ def streaming_callback(chunk: StreamingChunk) -> None:
     response = gemini.run(messages=messages)

     mock_model.send_message.assert_called_once()
-    assert streaming_callback_called == ["First part", "Second part"]
+    assert streaming_callback_called == ["First part", " Second part"]
     assert "replies" in response
-    assert len(response["replies"]) == 1
+    assert len(response["replies"]) > 0

-    chat_message = response["replies"][0]
-    assert chat_message.content
-    assert chat_message.is_from(ChatRole.ASSISTANT)
+    assert response["replies"][0].content == "First part Second part"
+    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


 def test_serialization_deserialization_pipeline():
@@ -304,53 +301,3 @@ def test_serialization_deserialization_pipeline():

     new_pipeline = Pipeline.from_dict(pipeline_dict)
     assert new_pipeline == pipeline
-
-@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
-def test_function_call_and_execute(mock_generative_model):
-    mock_model = Mock()
-    mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model"))
-    mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate])
-
-    mock_model.send_message.return_value = mock_response
-    mock_model.start_chat.return_value = mock_model
-    mock_generative_model.return_value = mock_model
-
-    get_current_weather_func = FunctionDeclaration(
-        name="get_current_weather",
-        description="Get the current weather in a given location",
-        parameters={
-            "type": "object",
-            "properties": {
-                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
-                "unit": {
-                    "type": "string",
-                    "enum": [
-                        "celsius",
-                        "fahrenheit",
-                    ],
-                },
-            },
-            "required": ["location"],
-        },
-    )
-
-    def get_current_weather(location: str, unit: str = "celsius"):
-        return {"weather": "sunny", "temperature": 21.8, "unit": unit}
-
-
-    tool = Tool(function_declarations=[get_current_weather_func])
-    messages = [ChatMessage.from_user(content="What is the temperature in celsius in Berlin?")]
-    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, tools=[tool])
-
-    response = gemini.run(messages=messages)
-    assert "replies" in response
-    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
-
-    assert len(response["replies"]) > 0
-    print (response)
-
-    first_reply = response["replies"][0]
-    assert "tool_calls" in first_reply.meta
-    tool_calls = first_reply.meta["tool_calls"]
-
-