Skip to content

Commit

Permalink
Fix Google AI tests failing (#885)
Browse files Browse the repository at this point in the history
* Fix Google AI tests failing

* Fix GoogleAIGeminiChatGenerator to_dict and from_dict
  • Loading branch information
silvanocerza authored Jul 9, 2024
1 parent abfe76e commit 1ecfbfa
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import Any, Dict, List, Optional, Union

import google.generativeai as genai
from google.ai.generativelanguage import Content, Part, Tool
from google.ai.generativelanguage import Content, Part
from google.ai.generativelanguage import Tool as ToolProto
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.byte_stream import ByteStream
Expand Down Expand Up @@ -159,7 +160,14 @@ def to_dict(self) -> Dict[str, Any]:
tools=self._tools,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
data["init_parameters"]["tools"] = []
for tool in tools:
if isinstance(tool, Tool):
# There are multiple Tool types in the Google lib, one that is a protobuf class and
# another is a simple Python class. They have a similar structure but the Python class
# can't be easily serializated to a dict. We need to convert it to a protobuf class first.
tool = tool.to_proto() # noqa: PLW2901
data["init_parameters"]["tools"].append(ToolProto.serialize(tool))
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand All @@ -179,7 +187,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools]
deserialized_tools = []
for tool in tools:
# Tools are always serialized as a protobuf class, so we need to deserialize them first
# to be able to convert them to the Python class.
proto = ToolProto.deserialize(tool)
deserialized_tools.append(
Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution)
)
data["init_parameters"]["tools"] = deserialized_tools
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
Expand Down
61 changes: 16 additions & 45 deletions integrations/google_ai/tests/generators/chat/test_chat_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from unittest.mock import patch

import pytest
from google.ai.generativelanguage import FunctionDeclaration, Tool
from google.generativeai import GenerationConfig, GenerativeModel
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool
from haystack.dataclasses.chat_message import ChatMessage

from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator
Expand Down Expand Up @@ -158,33 +157,16 @@ def test_from_dict(monkeypatch):
top_k=2,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini._tools == [
Tool(
function_declarations=[
FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)
]
)
]
assert len(gemini._tools) == 1
assert len(gemini._tools[0].function_declarations) == 1
assert gemini._tools[0].function_declarations[0].name == "get_current_weather"
assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location"
assert (
gemini._tools[0].function_declarations[0].parameters.properties["location"].description
== "The city and state, e.g. San Francisco, CA"
)
assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"]
assert gemini._tools[0].function_declarations[0].parameters.required == ["location"]
assert isinstance(gemini._model, GenerativeModel)


Expand All @@ -195,22 +177,11 @@ def test_run():
def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
return {"weather": "sunny", "temperature": 21.8, "unit": unit}

get_current_weather_func = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type_": "OBJECT",
"properties": {
"location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type_": "STRING",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
get_current_weather_func = FunctionDeclaration.from_function(
get_current_weather,
descriptions={
"location": "The city and state, e.g. San Francisco, CA",
"unit": "The temperature unit of measurement, e.g. celsius or fahrenheit",
},
)

Expand Down

0 comments on commit 1ecfbfa

Please sign in to comment.