From 64e192d13532e9cf2d821fc81cc842f893124477 Mon Sep 17 00:00:00 2001
From: anakin87
Date: Tue, 12 Nov 2024 16:29:19 +0100
Subject: [PATCH] fix: VertexAIGeminiGenerator - remove support for tools and change output type

---
 .../generators/google_vertex/gemini.py       |  63 ++-------
 .../google_vertex/tests/test_gemini.py       | 125 ++----------------
 2 files changed, 21 insertions(+), 167 deletions(-)

diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
index 737f2e668..c9473b428 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
@@ -15,8 +15,6 @@
     HarmBlockThreshold,
     HarmCategory,
     Part,
-    Tool,
-    ToolConfig,
 )
 
 logger = logging.getLogger(__name__)
@@ -50,6 +48,16 @@ class VertexAIGeminiGenerator:
     ```
     """
 
+    def __new__(cls, *_, **kwargs):
+        if "tools" in kwargs or "tool_config" in kwargs:
+            msg = (
+                "VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. "
+                "Use VertexAIGeminiChatGenerator instead."
+            )
+            raise TypeError(msg)
+        return super(VertexAIGeminiGenerator, cls).__new__(cls)  # noqa: UP008
+        # super(__class__, cls) is needed because of the component decorator
+
     def __init__(
         self,
         *,
@@ -58,8 +66,6 @@ def __init__(
         location: Optional[str] = None,
         generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
         safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
-        tools: Optional[List[Tool]] = None,
-        tool_config: Optional[ToolConfig] = None,
         system_instruction: Optional[Union[str, ByteStream, Part]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
@@ -86,10 +92,6 @@ def __init__(
             for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold)
             and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory)
             for more details.
-        :param tools: List of tools to use when generating content. See the documentation for
-            [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
-            the list of supported arguments.
-        :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
         :param system_instruction: Default system instruction to use for generating content.
         :param streaming_callback: A callback function that is called when a new token is received from the stream.
             The callback function accepts StreamingChunk as an argument.
@@ -105,8 +107,6 @@ def __init__(
         # model parameters
         self._generation_config = generation_config
         self._safety_settings = safety_settings
-        self._tools = tools
-        self._tool_config = tool_config
         self._system_instruction = system_instruction
         self._streaming_callback = streaming_callback
 
@@ -115,8 +115,6 @@ def __init__(
             self._model_name,
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
-            tools=self._tools,
-            tool_config=self._tool_config,
             system_instruction=self._system_instruction,
         )
 
@@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
             "stop_sequences": config._raw_generation_config.stop_sequences,
         }
 
-    def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
-        """Serializes the ToolConfig object into a dictionary."""
-
-        mode = tool_config._gapic_tool_config.function_calling_config.mode
-        allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names
-        config_dict = {"function_calling_config": {"mode": mode}}
-
-        if allowed_function_names:
-            config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names
-
-        return config_dict
-
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes the component to a dictionary.
@@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]:
             location=self._location,
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
-            tools=self._tools,
-            tool_config=self._tool_config,
             system_instruction=self._system_instruction,
             streaming_callback=callback_name,
         )
-        if (tools := data["init_parameters"].get("tools")) is not None:
-            data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
-        if (tool_config := data["init_parameters"].get("tool_config")) is not None:
-            data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)
+
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
         return data
@@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
             Deserialized component.
""" - def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: - """Deserializes the ToolConfig object from a dictionary.""" - function_calling_config = config_dict["function_calling_config"] - return ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=function_calling_config["mode"], - allowed_function_names=function_calling_config.get("allowed_function_names"), - ) - ) - - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 277851224..ff692c6f4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,38 +1,17 @@ from unittest.mock import MagicMock, Mock, patch +import pytest from haystack import Pipeline from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk from vertexai.generative_models import ( - FunctionDeclaration, GenerationConfig, HarmBlockThreshold, HarmCategory, - Tool, - ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config assert gemini._system_instruction == "Please provide brief answers." +def test_init_fails_with_tools_or_tool_config(): + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tools=["tool1", "tool2"]) + + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tool_config={"custom": "config"}) + + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): @@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, - "tool_config": None, "system_instruction": None, }, } @@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - "property_ordering": ["location", "unit"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, - "tool_config": None, "system_instruction": None, }, } @@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id is None assert gemini._location is None assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None assert gemini._system_instruction is None assert gemini._generation_config is None @@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - "description": "Get the current weather in a given location", - } - ] - } - ], "streaming_callback": None, - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._tool_config, ToolConfig) assert gemini._system_instruction == "Please provide brief answers." - assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
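
Usage sketch (illustrative, not part of the patch): the snippet below shows the behavior this change enforces. Passing `tools` or `tool_config` now fails fast with a `TypeError`, and `run()` returns `replies` typed as `List[str]`. The `project_id` value and the prompt are made-up placeholders; a real call also needs valid Google Cloud credentials.

```python
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

# Tool parameters are rejected in __new__, before __init__ runs.
try:
    VertexAIGeminiGenerator(tools=[])
except TypeError as exc:
    print(exc)  # error message points to VertexAIGeminiChatGenerator instead

# Without tools, the generator works as before; replies are plain strings now.
gemini = VertexAIGeminiGenerator(project_id="my-gcp-project")  # placeholder project id
result = gemini.run(parts=["What is the most interesting thing you know?"])
for reply in result["replies"]:
    print(reply)
```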