Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make "project-id" parameter optional during initialization #1141

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator:
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator

gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id)
gemini_chat = VertexAIGeminiChatGenerator()

messages = [ChatMessage.from_user("Tell me the name of a movie")]
res = gemini_chat.run(messages)
Expand All @@ -50,7 +50,7 @@ def __init__(
self,
*,
model: str = "gemini-1.5-flash",
project_id: str,
project_id: Optional[str] = None,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
Expand All @@ -65,7 +65,7 @@ def __init__(
Authenticates using Google Cloud Application Default Credentials (ADCs).
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).

:param project_id: ID of the GCP project to use.
:param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication.
:param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
:param location: The default location to use when making API calls, if not set uses us-central-1.
Defaults to None.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class VertexAIGeminiGenerator:
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator


gemini = VertexAIGeminiGenerator(project_id=project_id)
gemini = VertexAIGeminiGenerator()
result = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in result["replies"]:
print(answer)
Expand All @@ -54,7 +54,7 @@ def __init__(
self,
*,
model: str = "gemini-1.5-flash",
project_id: str,
project_id: Optional[str] = None,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
Expand All @@ -69,7 +69,7 @@ def __init__(
Authenticates using Google Cloud Application Default Credentials (ADCs).
For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc).

:param project_id: ID of the GCP project to use.
:param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication.
:param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models.
:param location: The default location to use when making API calls, if not set uses us-central-1.
:param generation_config: The generation config to use.
Expand Down
21 changes: 11 additions & 10 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model):
@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
)
gemini = VertexAIGeminiChatGenerator()
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"project_id": None,
"location": None,
"generation_config": None,
"safety_settings": None,
Expand Down Expand Up @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiChatGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
Expand All @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"location": "TestLocation",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
Expand Down Expand Up @@ -194,7 +193,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
{
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"project_id": "TestID123",
"project_id": None,
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
Expand All @@ -205,7 +204,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._project_id is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
Expand All @@ -221,6 +220,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator",
"init_parameters": {
"project_id": "TestID123",
"location": "TestLocation",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
Expand Down Expand Up @@ -272,6 +272,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._tool_config, ToolConfig)
Expand All @@ -296,7 +297,7 @@ def test_run(mock_generative_model):
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
]
gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None)
gemini = VertexAIGeminiChatGenerator()
response = gemini.run(messages=messages)

mock_model.send_message.assert_called_once()
Expand All @@ -321,7 +322,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback)
gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback)
messages = [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France?"),
Expand All @@ -336,7 +337,7 @@ def test_serialization_deserialization_pipeline():
pipeline = Pipeline()
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template))
pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123"))
pipeline.add_component("gemini", VertexAIGeminiChatGenerator())
pipeline.connect("prompt_builder.prompt", "gemini.messages")

pipeline_dict = pipeline.to_dict()
Expand Down
23 changes: 13 additions & 10 deletions integrations/google_vertex/tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,12 @@ def test_init(mock_vertexai_init, _mock_generative_model):
@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel")
def test_to_dict(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
)
gemini = VertexAIGeminiGenerator()
assert gemini.to_dict() == {
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"project_id": None,
"location": None,
"generation_config": None,
"safety_settings": None,
Expand Down Expand Up @@ -120,6 +118,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):

gemini = VertexAIGeminiGenerator(
project_id="TestID123",
location="TestLocation",
generation_config=generation_config,
safety_settings=safety_settings,
tools=[tool],
Expand All @@ -131,7 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model):
"init_parameters": {
"model": "gemini-1.5-flash",
"project_id": "TestID123",
"location": None,
"location": "TestLocation",
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
Expand Down Expand Up @@ -181,7 +180,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
{
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"project_id": None,
"location": None,
"model": "gemini-1.5-flash",
"generation_config": None,
"safety_settings": None,
Expand All @@ -194,7 +194,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model):
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._project_id is None
assert gemini._location is None
assert gemini._safety_settings is None
assert gemini._tools is None
assert gemini._tool_config is None
Expand All @@ -210,6 +211,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):
"type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator",
"init_parameters": {
"project_id": "TestID123",
"location": "TestLocation",
"model": "gemini-1.5-flash",
"generation_config": {
"temperature": 0.5,
Expand Down Expand Up @@ -261,6 +263,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model):

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._project_id == "TestID123"
assert gemini._location == "TestLocation"
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])])
assert isinstance(gemini._generation_config, GenerationConfig)
Expand All @@ -277,7 +280,7 @@ def test_run(mock_generative_model):
mock_model.generate_content.return_value = MagicMock()
mock_generative_model.return_value = mock_model

gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None)
gemini = VertexAIGeminiGenerator()

response = gemini.run(["What's the weather like today?"])

Expand All @@ -303,7 +306,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
nonlocal streaming_callback_called
streaming_callback_called = True

gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback)
gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback)
gemini.run(["Come on, stream!"])
assert streaming_callback_called

Expand All @@ -316,7 +319,7 @@ def test_serialization_deserialization_pipeline():
pipeline = Pipeline()

pipeline.add_component("prompt_builder", PromptBuilder(template=template))
pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123"))
pipeline.add_component("gemini", VertexAIGeminiGenerator())
pipeline.connect("prompt_builder", "gemini")

pipeline_dict = pipeline.to_dict()
Expand Down
Loading