small fixes + test
anakin87 committed Nov 19, 2024
1 parent 245991d commit eb3c088
Showing 2 changed files with 23 additions and 11 deletions.
@@ -313,18 +313,24 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMessage]
"""
replies: List[ChatMessage] = []
metadata = response_body.to_dict()

# currently Google only supports one candidate and usage metadata reflects this
# this should be refactored when multiple candidates are supported
usage_metadata_openai_format = {}

usage_metadata = metadata.get("usage_metadata")
if usage_metadata:
usage_metadata_openai_format = {
"prompt_tokens": usage_metadata["prompt_token_count"],
"completion_tokens": usage_metadata["candidates_token_count"],
"total_tokens": usage_metadata["total_token_count"],
}

for idx, candidate in enumerate(response_body.candidates):
candidate_metadata = metadata["candidates"][idx]
candidate_metadata.pop("content", None) # we remove content from the metadata

# align openai api response
usage_metadata = metadata.get("usage_metadata")
if usage_metadata:
candidate_metadata["usage"] = {
"prompt_tokens": usage_metadata["prompt_token_count"],
"completion_tokens": usage_metadata["candidates_token_count"],
"total_tokens": usage_metadata["total_token_count"],
}
if usage_metadata_openai_format:
candidate_metadata["usage"] = usage_metadata_openai_format

for part in candidate.content.parts:
if part.text != "":
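
For reference, here is a minimal standalone sketch of the refactored flow above, using a made-up metadata dict in place of a real GenerateContentResponse.to_dict() result: the usage block is converted to the OpenAI-style format once, before the candidates loop, and then attached to each candidate's metadata.

# Minimal sketch of the refactored logic, assuming a metadata dict shaped like
# the one returned by GenerateContentResponse.to_dict(); all values are made up.
metadata = {
    "usage_metadata": {
        "prompt_token_count": 12,
        "candidates_token_count": 34,
        "total_token_count": 46,
    },
    "candidates": [{"content": {"parts": [{"text": "Hello!"}]}, "finish_reason": "STOP"}],
}

# convert usage once, outside the per-candidate loop
# (Google currently returns a single candidate, so one usage block covers all of them)
usage_metadata_openai_format = {}
usage_metadata = metadata.get("usage_metadata")
if usage_metadata:
    usage_metadata_openai_format = {
        "prompt_tokens": usage_metadata["prompt_token_count"],
        "completion_tokens": usage_metadata["candidates_token_count"],
        "total_tokens": usage_metadata["total_token_count"],
    }

for candidate_metadata in metadata["candidates"]:
    candidate_metadata.pop("content", None)  # content is not kept in the metadata
    if usage_metadata_openai_format:
        candidate_metadata["usage"] = usage_metadata_openai_format

print(metadata["candidates"][0]["usage"])
# {'prompt_tokens': 12, 'completion_tokens': 34, 'total_tokens': 46}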
integrations/google_ai/tests/generators/chat/test_chat_gemini.py: 10 changes (8 additions & 2 deletions)
@@ -295,5 +295,11 @@ def test_past_conversation():
     ]
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
-    assert len(response["replies"]) > 0
-    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+    replies = response["replies"]
+    assert len(replies) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in replies)
+
+    assert all("usage" in reply.meta for reply in replies)
+    assert all("prompt_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("completion_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("total_tokens" in reply.meta["usage"] for reply in replies)
