From b897441335f3daad9fc9a5599287af6ddd3c5334 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 6 Sep 2024 16:25:54 +0200 Subject: [PATCH] test: Add tests for VertexAIGeminiChatGenerator and migrate from preview package in vertexai (#1042) * Add tests for chat generator and migrate from preview package to a stable version of vertexai generative_model --- .github/workflows/google_vertex.yml | 2 +- .../generators/google_vertex/chat/gemini.py | 2 +- .../google_vertex/tests/chat/test_gemini.py | 295 ++++++++++++++++++ .../google_vertex/tests/test_gemini.py | 25 +- 4 files changed, 319 insertions(+), 5 deletions(-) create mode 100644 integrations/google_vertex/tests/chat/test_gemini.py diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml index 78ba5694b..34c0cf07c 100644 --- a/.github/workflows/google_vertex.yml +++ b/.github/workflows/google_vertex.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - name: Support longpaths diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index e5ca1166d..8cdb58d2d 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -8,7 +8,7 @@ from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack.utils import deserialize_callable, serialize_callable from vertexai import init as vertexai_init -from vertexai.preview.generative_models import ( +from vertexai.generative_models import ( Content, GenerationConfig, GenerationResponse, diff --git a/integrations/google_vertex/tests/chat/test_gemini.py 
b/integrations/google_vertex/tests/chat/test_gemini.py new file mode 100644 index 000000000..a1564b9f2 --- /dev/null +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -0,0 +1,295 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.dataclasses import ChatMessage, StreamingChunk +from vertexai.generative_models import ( + Content, + FunctionDeclaration, + GenerationConfig, + GenerationResponse, + HarmBlockThreshold, + HarmCategory, + Part, + Tool, +) + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator + +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France"), + ] + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_init(mock_vertexai_init, _mock_generative_model): + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + location="TestLocation", + 
generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + mock_vertexai_init.assert_called() + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == [tool] + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_to_dict(_mock_vertexai_init, _mock_generative_model): + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + 
"location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_from_dict(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._generation_config is None + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": 
"haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + assert isinstance(gemini._generation_config, GenerationConfig) + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_run(mock_generative_model): + mock_model = Mock() + mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model")) + mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + + mock_model.send_message.return_value = mock_response + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + gemini = 
VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_run_with_streaming_callback(mock_generative_model): + mock_model = Mock() + mock_responses = iter( + [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] + ) + + mock_model.send_message.return_value = mock_responses + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + streaming_callback_called = [] + + def streaming_callback(chunk: StreamingChunk) -> None: + streaming_callback_called.append(chunk.content) + + gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + assert streaming_callback_called == ["First part", "Second part"] + + +def test_serialization_deserialization_pipeline(): + + pipeline = Pipeline() + template = [ChatMessage.from_user("Translate to {{ target_language }}. 
Context: {{ snippet }}; Translation:")] + pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template)) + pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123")) + pipeline.connect("prompt_builder.prompt", "gemini.messages") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 8d08e0859..bb96ec409 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,7 +1,9 @@ from unittest.mock import MagicMock, Mock, patch +from haystack import Pipeline +from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk -from vertexai.preview.generative_models import ( +from vertexai.generative_models import ( FunctionDeclaration, GenerationConfig, HarmBlockThreshold, @@ -191,18 +193,18 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "function_declarations": [ { "name": "get_current_weather", - "description": "Get the current weather in a given location", "parameters": { "type_": "OBJECT", "properties": { + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, "location": { "type_": "STRING", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, + "description": "Get the current weather in a given location", } ] } @@ -254,3 +256,20 @@ def streaming_callback(_chunk: StreamingChunk) -> None: gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called + + +def test_serialization_deserialization_pipeline(): + template = """ + Answer the following questions: + 1. 
What is the weather like today? + """ + pipeline = Pipeline() + + pipeline.add_component("prompt_builder", PromptBuilder(template=template)) + pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123")) + pipeline.connect("prompt_builder", "gemini") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline