Skip to content

Commit

Permalink
test: Add tests for VertexAIGeminiChatGenerator and migrate from prev…
Browse files Browse the repository at this point in the history
…iew package in vertexai (#1042)

* Add tests for chat generator and migrate from preview package to a stable version of vertexai generative_model
  • Loading branch information
Amnah199 authored Sep 6, 2024
1 parent 12daeaf commit b897441
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/google_vertex.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- name: Support longpaths
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.utils import deserialize_callable, serialize_callable
from vertexai import init as vertexai_init
from vertexai.preview.generative_models import (
from vertexai.generative_models import (
Content,
GenerationConfig,
GenerationResponse,
Expand Down
295 changes: 295 additions & 0 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage, StreamingChunk
from vertexai.generative_models import (
Content,
FunctionDeclaration,
GenerationConfig,
GenerationResponse,
HarmBlockThreshold,
HarmCategory,
Part,
Tool,
)

from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator

# Shared tool schema used by the tool-related tests below.
# NOTE(review): the schema uses "type_" keys (trailing underscore) rather than
# JSON-schema "type" — presumably the proto-field naming that
# vertexai.generative_models.FunctionDeclaration expects; confirm against the SDK.
GET_CURRENT_WEATHER_FUNC = FunctionDeclaration(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={
        "type_": "OBJECT",
        "properties": {
            "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"},
            "unit": {
                "type_": "STRING",
                "enum": [
                    "celsius",
                    "fahrenheit",
                ],
            },
        },
        "required": ["location"],
    },
)


@pytest.fixture
def chat_messages():
    """A minimal two-turn conversation: a system prompt followed by a user question."""
    system_message = ChatMessage.from_system("You are a helpful assistant")
    user_message = ChatMessage.from_user("What's the capital of France")
    return [system_message, user_message]


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_init(_mock_generative_model, mock_vertexai_init):
    """Constructing the generator stores every init parameter and initializes Vertex AI.

    BUG FIX: stacked ``@patch`` decorators inject mocks bottom-up, so the first
    parameter receives the *innermost* patch (``GenerativeModel``). The original
    signature named the first parameter ``mock_vertexai_init``, which meant
    ``mock_vertexai_init.assert_called()`` was actually asserting on the
    GenerativeModel mock. The parameter names are now bound to the correct mocks.
    """
    generation_config = GenerationConfig(
        candidate_count=1,
        stop_sequences=["stop"],
        max_output_tokens=10,
        temperature=0.5,
        top_p=0.5,
        top_k=0.5,
    )
    safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

    tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])

    gemini = VertexAIGeminiChatGenerator(
        project_id="TestID123",
        location="TestLocation",
        generation_config=generation_config,
        safety_settings=safety_settings,
        tools=[tool],
    )
    # The component is expected to call vertexai.init with the given project/location.
    mock_vertexai_init.assert_called()
    assert gemini._model_name == "gemini-1.5-flash"
    assert gemini._generation_config == generation_config
    assert gemini._safety_settings == safety_settings
    assert gemini._tools == [tool]


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_to_dict(_mock_generative_model, _mock_vertexai_init):
    """to_dict() on a default-configured generator emits None for every optional field."""
    gemini = VertexAIGeminiChatGenerator(project_id="TestID123")

    expected = {
        "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
        "init_parameters": {
            "model": "gemini-1.5-flash",
            "project_id": "TestID123",
            "location": None,
            "generation_config": None,
            "safety_settings": None,
            "streaming_callback": None,
            "tools": None,
        },
    }
    assert gemini.to_dict() == expected


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_to_dict_with_params(_mock_generative_model, _mock_vertexai_init):
    """to_dict() serializes explicitly configured generation/safety/tool settings."""
    gemini = VertexAIGeminiChatGenerator(
        project_id="TestID123",
        generation_config=GenerationConfig(
            candidate_count=1,
            stop_sequences=["stop"],
            max_output_tokens=10,
            temperature=0.5,
            top_p=0.5,
            top_k=2,
        ),
        safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
        tools=[Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])],
    )

    expected = {
        "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
        "init_parameters": {
            "model": "gemini-1.5-flash",
            "project_id": "TestID123",
            "location": None,
            "generation_config": {
                # top_k is serialized as a float by the SDK even when set as int
                "temperature": 0.5,
                "top_p": 0.5,
                "top_k": 2.0,
                "candidate_count": 1,
                "max_output_tokens": 10,
                "stop_sequences": ["stop"],
            },
            "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
            "streaming_callback": None,
            "tools": [
                {
                    "function_declarations": [
                        {
                            "name": "get_current_weather",
                            "description": "Get the current weather in a given location",
                            "parameters": {
                                "type_": "OBJECT",
                                "properties": {
                                    "location": {
                                        "type_": "STRING",
                                        "description": "The city and state, e.g. San Francisco, CA",
                                    },
                                    "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
                                },
                                "required": ["location"],
                            },
                        }
                    ]
                }
            ],
        },
    }
    assert gemini.to_dict() == expected


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_from_dict(_mock_generative_model, _mock_vertexai_init):
    """from_dict() restores a generator whose optional fields were all None."""
    serialized = {
        "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
        "init_parameters": {
            "project_id": "TestID123",
            "model": "gemini-1.5-flash",
            "generation_config": None,
            "safety_settings": None,
            "tools": None,
            "streaming_callback": None,
        },
    }

    gemini = VertexAIGeminiChatGenerator.from_dict(serialized)

    assert gemini._model_name == "gemini-1.5-flash"
    assert gemini._project_id == "TestID123"
    assert gemini._safety_settings is None
    assert gemini._tools is None
    assert gemini._generation_config is None


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init")
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_from_dict_with_param(_mock_generative_model, _mock_vertexai_init):
    """from_dict() rebuilds generation config, safety settings and tool objects."""
    serialized = {
        "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
        "init_parameters": {
            "project_id": "TestID123",
            "model": "gemini-1.5-flash",
            "generation_config": {
                "temperature": 0.5,
                "top_p": 0.5,
                "top_k": 0.5,
                "candidate_count": 1,
                "max_output_tokens": 10,
                "stop_sequences": ["stop"],
            },
            "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH},
            "tools": [
                {
                    "function_declarations": [
                        {
                            "name": "get_current_weather",
                            "description": "Get the current weather in a given location",
                            "parameters": {
                                "type_": "OBJECT",
                                "properties": {
                                    "location": {
                                        "type_": "STRING",
                                        "description": "The city and state, e.g. San Francisco, CA",
                                    },
                                    "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
                                },
                                "required": ["location"],
                            },
                        }
                    ]
                }
            ],
            "streaming_callback": None,
        },
    }

    gemini = VertexAIGeminiChatGenerator.from_dict(serialized)

    assert gemini._model_name == "gemini-1.5-flash"
    assert gemini._project_id == "TestID123"
    assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
    # Tool does not implement __eq__, so compare via repr.
    assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
    assert isinstance(gemini._generation_config, GenerationConfig)


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_run(mock_generative_model):
    """run() forwards the conversation to the chat session exactly once."""
    reply = Content(parts=[Part.from_text("This is a generated response.")], role="model")
    candidate = Mock(content=reply)
    response = MagicMock(spec=GenerationResponse, candidates=[candidate])

    session = Mock()
    session.send_message.return_value = response
    session.start_chat.return_value = session
    mock_generative_model.return_value = session

    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
    gemini.run(
        messages=[
            ChatMessage.from_system("You are a helpful assistant"),
            ChatMessage.from_user("What's the capital of France?"),
        ]
    )

    session.send_message.assert_called_once()


@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_run_with_streaming_callback(mock_generative_model):
    """When a streaming callback is set, each streamed chunk is passed to it in order."""
    streamed = [
        MagicMock(spec=GenerationResponse, text="First part"),
        MagicMock(spec=GenerationResponse, text="Second part"),
    ]

    session = Mock()
    session.send_message.return_value = iter(streamed)
    session.start_chat.return_value = session
    mock_generative_model.return_value = session

    received_chunks = []

    def streaming_callback(chunk: StreamingChunk) -> None:
        received_chunks.append(chunk.content)

    gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
    gemini.run(
        messages=[
            ChatMessage.from_system("You are a helpful assistant"),
            ChatMessage.from_user("What's the capital of France?"),
        ]
    )

    session.send_message.assert_called_once()
    assert received_chunks == ["First part", "Second part"]


def test_serialization_deserialization_pipeline():
    """A pipeline containing the chat generator survives a to_dict/from_dict round trip."""
    template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]

    pipeline = Pipeline()
    pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template))
    pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123"))
    pipeline.connect("prompt_builder.prompt", "gemini.messages")

    rebuilt = Pipeline.from_dict(pipeline.to_dict())
    assert rebuilt == pipeline
25 changes: 22 additions & 3 deletions integrations/google_vertex/tests/test_gemini.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import MagicMock, Mock, patch

from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.dataclasses import StreamingChunk
from vertexai.preview.generative_models import (
from vertexai.generative_models import (
FunctionDeclaration,
GenerationConfig,
HarmBlockThreshold,
Expand Down Expand Up @@ -191,18 +193,18 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"function_declarations": [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type_": "OBJECT",
"properties": {
"unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
"location": {
"type_": "STRING",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
"description": "Get the current weather in a given location",
}
]
}
Expand Down Expand Up @@ -254,3 +256,20 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback)
gemini.run(["Come on, stream!"])
assert streaming_callback_called


def test_serialization_deserialization_pipeline():
    """A pipeline containing the (non-chat) generator survives a serialization round trip."""
    template = """
Answer the following questions:
1. What is the weather like today?
"""
    pipeline = Pipeline()
    pipeline.add_component("prompt_builder", PromptBuilder(template=template))
    pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123"))
    pipeline.connect("prompt_builder", "gemini")

    rebuilt = Pipeline.from_dict(pipeline.to_dict())
    assert rebuilt == pipeline

0 comments on commit b897441

Please sign in to comment.