From eb3c08875b351a93c6d51ab586086977c16f629e Mon Sep 17 00:00:00 2001
From: anakin87
Date: Tue, 19 Nov 2024 16:35:20 +0100
Subject: [PATCH] small fixes + test

---
 .../generators/google_ai/chat/gemini.py       | 24 ++++++++++++-------
 .../tests/generators/chat/test_chat_gemini.py | 10 ++++++--
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index 2d02f77ac..dbcab619d 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -313,18 +313,24 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
         """
         replies: List[ChatMessage] = []
         metadata = response_body.to_dict()
+
+        # currently Google only supports one candidate and usage metadata reflects this
+        # this should be refactored when multiple candidates are supported
+        usage_metadata_openai_format = {}
+
+        usage_metadata = metadata.get("usage_metadata")
+        if usage_metadata:
+            usage_metadata_openai_format = {
+                "prompt_tokens": usage_metadata["prompt_token_count"],
+                "completion_tokens": usage_metadata["candidates_token_count"],
+                "total_tokens": usage_metadata["total_token_count"],
+            }
+
         for idx, candidate in enumerate(response_body.candidates):
             candidate_metadata = metadata["candidates"][idx]
             candidate_metadata.pop("content", None)  # we remove content from the metadata
-
-            # align openai api response
-            usage_metadata = metadata.get("usage_metadata")
-            if usage_metadata:
-                candidate_metadata["usage"] = {
-                    "prompt_tokens": usage_metadata["prompt_token_count"],
-                    "completion_tokens": usage_metadata["candidates_token_count"],
-                    "total_tokens": usage_metadata["total_token_count"],
-                }
+            if usage_metadata_openai_format:
+                candidate_metadata["usage"] = usage_metadata_openai_format
 
             for part in candidate.content.parts:
                 if part.text != "":
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index c4372db0d..cb42f0ff8 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -295,5 +295,11 @@ def test_past_conversation():
     ]
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
-    assert len(response["replies"]) > 0
-    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+    replies = response["replies"]
+    assert len(replies) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in replies)
+
+    assert all("usage" in reply.meta for reply in replies)
+    assert all("prompt_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("completion_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("total_tokens" in reply.meta["usage"] for reply in replies)