Skip to content

Commit

Permalink
Fixed error in vertex streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 20, 2024
1 parent fee4a70 commit 1f53cd9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.SYSTEM for reply in response["replies"])
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
replies.append(
ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata)
)
elif part.function_call is not None:
elif part.function_call:
metadata["function_call"] = part.function_call
replies.append(
ChatMessage(
Expand All @@ -260,21 +260,28 @@ def _get_stream_response(
:returns: The extracted response with the content of all streaming chunks.
"""
replies = []

content: Union[str, Dict[Any, Any]] = ""
for chunk in stream:
metadata = chunk.to_dict()
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
for candidate in chunk.candidates:
for part in candidate.content.parts:

if part._raw_part.text:
content = chunk.text
replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif part.function_call:
metadata["function_call"] = part.function_call
content = dict(part.function_call.args.items())
replies.append(
ChatMessage(
content=content,
role=ChatRole.ASSISTANT,
name=part.function_call.name,
meta=metadata,
)
)
streaming_chunk = StreamingChunk(content=content, meta=chunk.to_dict())
streaming_callback(streaming_chunk)

if chunk.text != "":
replies.append(ChatMessage(chunk.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif chunk.function_call is not None:
metadata["function_call"] = chunk.function_call
replies.append(
ChatMessage(
content=dict(chunk.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=chunk.function_call.name,
meta=metadata,
)
)
return replies
12 changes: 4 additions & 8 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,30 +261,26 @@ def test_run(mock_generative_model):
def test_run_with_streaming_callback(mock_generative_model):
mock_model = Mock()
mock_responses = iter(
[MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text=" Second part")]
[MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")]
)

mock_model.send_message.return_value = mock_responses
mock_model.start_chat.return_value = mock_model
mock_generative_model.return_value = mock_model

streaming_callback_called = []

def streaming_callback(chunk: StreamingChunk) -> None:
streaming_callback_called.append(chunk.content)
def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
assert streaming_callback_called == ["First part", " Second part"]
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


def test_serialization_deserialization_pipeline():
Expand Down

0 comments on commit 1f53cd9

Please sign in to comment.