Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: VertexAIGeminiGenerator - remove support for tools and change output type #1180

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
HarmBlockThreshold,
HarmCategory,
Part,
Tool,
ToolConfig,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,6 +48,16 @@ class VertexAIGeminiGenerator:
```
"""

def __new__(cls, *_, **kwargs):
if "tools" in kwargs or "tool_config" in kwargs:
msg = (
"VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. "
"Use VertexAIGeminiChatGenerator instead."
)
raise TypeError(msg)
return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008
# super(__class__, cls) is needed because of the component decorator

def __init__(
self,
*,
Expand All @@ -58,8 +66,6 @@ def __init__(
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
tool_config: Optional[ToolConfig] = None,
system_instruction: Optional[Union[str, ByteStream, Part]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
):
Expand All @@ -86,10 +92,6 @@ def __init__(
for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold)
and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory)
for more details.
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool)
the list of supported arguments.
:param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig)
:param system_instruction: Default system instruction to use for generating content.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
Expand All @@ -105,8 +107,6 @@ def __init__(
# model parameters
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
self._tool_config = tool_config
self._system_instruction = system_instruction
self._streaming_callback = streaming_callback

Expand All @@ -115,8 +115,6 @@ def __init__(
self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
)

Expand All @@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]:
"""Serializes the ToolConfig object into a dictionary."""

mode = tool_config._gapic_tool_config.function_calling_config.mode
allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names
config_dict = {"function_calling_config": {"mode": mode}}

if allowed_function_names:
config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names

return config_dict

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]:
location=self._location,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
tool_config=self._tool_config,
system_instruction=self._system_instruction,
streaming_callback=callback_name,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config)

if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data
Expand All @@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator":
Deserialized component.
"""

def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig:
"""Deserializes the ToolConfig object from a dictionary."""
function_calling_config = config_dict["function_calling_config"]
return ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=function_calling_config["mode"],
allowed_function_names=function_calling_config.get("allowed_function_names"),
)
)

if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
if (tool_config := data["init_parameters"].get("tool_config")) is not None:
data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config)
if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
return default_from_dict(cls, data)
Expand All @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(replies=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[str])
def run(
self,
parts: Variadic[Union[str, ByteStream, Part]],
Expand Down Expand Up @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
replies.append(function_call)
return replies

def _get_stream_response(
Expand Down
125 changes: 9 additions & 116 deletions integrations/google_vertex/tests/test_gemini.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,17 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.dataclasses import StreamingChunk
from vertexai.generative_models import (
FunctionDeclaration,
GenerationConfig,
HarmBlockThreshold,
HarmCategory,
Tool,
ToolConfig,
)

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator

GET_CURRENT_WEATHER_FUNC = FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
)


@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
Expand All @@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model):
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
mock_vertexai_init.assert_called()
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini._tools == [tool]
assert gemini._tool_config == tool_config
assert gemini._system_instruction == "Please provide brief answers."


def test_init_fails_with_tools_or_tool_config():
with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"):
VertexAIGeminiGenerator(tools=["tool1", "tool2"])

with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"):
VertexAIGeminiGenerator(tool_config={"custom": "config"})


@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):
Expand All @@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model):
"generation_config": None,
"safety_settings": None,
"streaming_callback": None,
"tools": None,
"tool_config": None,
"system_instruction": None,
},
}
Expand All @@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])
tool_config = ToolConfig(
function_calling_config=ToolConfig.FunctionCallingConfig(
mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
allowed_function_names=["get_current_weather_func"],
)
)

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
tool_config=tool_config,
system_instruction="Please provide brief answers.",
)
assert gemini.to_dict() == {
Expand All @@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
},
"safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
"streaming_callback": None,
"tools": [
{
"function_declarations": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type_": "OBJECT",
"properties": {
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
"property_ordering": ["location", "unit"],
},
}
]
}
],
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}
Expand All @@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
"tools": None,
"streaming_callback": None,
"tool_config": None,
"system_instruction": None,
},
}
Expand All @@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id is None
assert gemini._location is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
assert gemini._system_instruction is None
assert gemini._generation_config is None

Expand All @@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"stop_sequences": ["stop"],
},
"safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
"tools": [
{
"function_declarations": [
{
"name": "get_current_weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit",
],
},
},
"required": ["location"],
},
"description": "Get the current weather in a given location",
}
]
}
],
"streaming_callback": None,
"tool_config": {
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": ["get_current_weather_func"],
}
},
"system_instruction": "Please provide brief answers.",
},
}
Expand All @@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._generation_config, GenerationConfig)
assert isinstance(gemini._tool_config, ToolConfig)
assert gemini._system_instruction == "Please provide brief answers."
assert (
gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY
)


@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
Expand Down