From 1ecfbfa6d08f24b1bd24ff83b6ae6941e40ab352 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Tue, 9 Jul 2024 15:54:39 +0200 Subject: [PATCH] Fix Google AI tests failing (#885) * Fix Google AI tests failing * Fix GoogleAIGeminiChatGenerator to_dict and from_dict --- .../generators/google_ai/chat/gemini.py | 24 ++++++-- .../tests/generators/chat/test_chat_gemini.py | 61 +++++-------------- 2 files changed, 36 insertions(+), 49 deletions(-) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index d3a8299fd..8b592a184 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -2,9 +2,10 @@ from typing import Any, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part +from google.ai.generativelanguage import Tool as ToolProto from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import HarmBlockThreshold, HarmCategory, Tool from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.byte_stream import ByteStream @@ -159,7 +160,14 @@ def to_dict(self) -> Dict[str, Any]: tools=self._tools, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] + data["init_parameters"]["tools"] = [] + for tool in tools: + if isinstance(tool, Tool): + # There are multiple Tool types in the Google lib, one that is a protobuf class and + # another is a simple Python class. They have a similar structure but the Python class + # can't be easily serializated to a dict. We need to convert it to a protobuf class first. + tool = tool.to_proto() # noqa: PLW2901 + data["init_parameters"]["tools"].append(ToolProto.serialize(tool)) if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -179,7 +187,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiChatGenerator": deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] + deserialized_tools = [] + for tool in tools: + # Tools are always serialized as a protobuf class, so we need to deserialize them first + # to be able to convert them to the Python class. + proto = ToolProto.deserialize(tool) + deserialized_tools.append( + Tool(function_declarations=proto.function_declarations, code_execution=proto.code_execution) + ) + data["init_parameters"]["tools"] = deserialized_tools if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 1b9ce4b1e..9b3124eab 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -2,9 +2,8 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel -from google.generativeai.types import HarmBlockThreshold, HarmCategory +from google.generativeai.types import FunctionDeclaration, HarmBlockThreshold, HarmCategory, Tool from haystack.dataclasses.chat_message import ChatMessage from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator @@ -158,33 +157,16 @@ def test_from_dict(monkeypatch): top_k=2, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [ - Tool( - function_declarations=[ - FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - ] - ) - ] + assert len(gemini._tools) == 1 + assert len(gemini._tools[0].function_declarations) == 1 + assert gemini._tools[0].function_declarations[0].name == "get_current_weather" + assert gemini._tools[0].function_declarations[0].description == "Get the current weather in a given location" + assert ( + gemini._tools[0].function_declarations[0].parameters.properties["location"].description + == "The city and state, e.g. San Francisco, CA" + ) + assert gemini._tools[0].function_declarations[0].parameters.properties["unit"].enum == ["celsius", "fahrenheit"] + assert gemini._tools[0].function_declarations[0].parameters.required == ["location"] assert isinstance(gemini._model, GenerativeModel) @@ -195,22 +177,11 @@ def test_run(): def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001 return {"weather": "sunny", "temperature": 21.8, "unit": unit} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], + get_current_weather_func = FunctionDeclaration.from_function( + get_current_weather, + descriptions={ + "location": "The city and state, e.g. San Francisco, CA", + "unit": "The temperature unit of measurement, e.g. celsius or fahrenheit", }, )