From 09be978b1c370099b9b6b82b2392ee2afbf47a59 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 16 Oct 2024 15:59:36 +0200 Subject: [PATCH 1/5] Make project-id param optional --- .../components/generators/google_vertex/chat/gemini.py | 4 ++-- .../components/generators/google_vertex/gemini.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index f09692daf..eaa1f5af1 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -50,7 +50,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -65,7 +65,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 2b1c1b477..3e4dcca7e 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -54,7 +54,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -69,7 +69,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. From 6888bf2f2734a23abbfe473c0bd87dd3349f236a Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 14:19:23 +0200 Subject: [PATCH 2/5] Update tests --- .../google_vertex/tests/chat/test_gemini.py | 15 ++++++++------- integrations/google_vertex/tests/test_gemini.py | 17 ++++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 6b1308dab..bffea59ea 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiChatGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiChatGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -194,7 +193,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -205,7 +204,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -221,6 +220,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -272,6 +272,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._tool_config, ToolConfig) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 9ec3529d7..81446b29b 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -78,14 +78,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -120,6 +118,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -131,7 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -181,7 +180,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, + "location": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -194,7 +194,8 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None + assert gemini._location is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -210,6 +211,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -261,6 +263,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) From d63131d086affdd5fd2cbf1dd801f7bd02aa6689 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 15:06:40 +0200 Subject: [PATCH 3/5] remove project-id from examples and some test --- .../components/generators/google_vertex/chat/gemini.py | 2 +- .../components/generators/google_vertex/gemini.py | 2 +- integrations/google_vertex/tests/chat/test_gemini.py | 6 +++--- integrations/google_vertex/tests/test_gemini.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index eaa1f5af1..c52f76dc6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator: from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator - gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id) + gemini_chat = VertexAIGeminiChatGenerator() messages = [ChatMessage.from_user("Tell me the name of a movie")] res = gemini_chat.run(messages) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 3e4dcca7e..737f2e668 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -32,7 +32,7 @@ class VertexAIGeminiGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator - gemini = VertexAIGeminiGenerator(project_id=project_id) + gemini = VertexAIGeminiGenerator() result = gemini.run(parts = ["What is the most interesting thing you know?"]) for answer in result["replies"]: print(answer) diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index bffea59ea..ed6afa9ec 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -297,7 +297,7 @@ def test_run(mock_generative_model): ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), ] - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiChatGenerator() response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() @@ -322,7 +322,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) messages = [ ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), @@ -337,7 +337,7 @@ def test_serialization_deserialization_pipeline(): pipeline = Pipeline() template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")] pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template)) - pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123")) + pipeline.add_component("gemini", VertexAIGeminiChatGenerator()) pipeline.connect("prompt_builder.prompt", "gemini.messages") pipeline_dict = pipeline.to_dict() diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 81446b29b..44e9bf1f9 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -280,7 +280,7 @@ def test_run(mock_generative_model): mock_model.generate_content.return_value = MagicMock() mock_generative_model.return_value = mock_model - gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiGenerator() response = gemini.run(["What's the weather like today?"]) @@ -306,7 +306,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called @@ -319,7 +319,7 @@ def test_serialization_deserialization_pipeline(): pipeline = Pipeline() pipeline.add_component("prompt_builder", PromptBuilder(template=template)) - pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123")) + pipeline.add_component("gemini", VertexAIGeminiGenerator()) pipeline.connect("prompt_builder", "gemini") pipeline_dict = pipeline.to_dict() From c10853b06e19113aba9fefc50d5b25b10ccde795 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 15:32:55 +0200 Subject: [PATCH 4/5] Small fix --- integrations/google_vertex/tests/chat/test_gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index ed6afa9ec..0d77bd9c6 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -337,7 +337,7 @@ def test_serialization_deserialization_pipeline(): pipeline = Pipeline() template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")] pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template)) - pipeline.add_component("gemini", VertexAIGeminiChatGenerator()) + pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123")) pipeline.connect("prompt_builder.prompt", "gemini.messages") pipeline_dict = pipeline.to_dict() From 4b60e8cfb7bb55c8766da75d7a303c01f6294d34 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 17 Oct 2024 15:35:21 +0200 Subject: [PATCH 5/5] Fix --- integrations/google_vertex/tests/test_gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 44e9bf1f9..b3d6dd5f5 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -319,7 +319,7 @@ def test_serialization_deserialization_pipeline(): pipeline = Pipeline() pipeline.add_component("prompt_builder", PromptBuilder(template=template)) - pipeline.add_component("gemini", VertexAIGeminiGenerator()) + pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123")) pipeline.connect("prompt_builder", "gemini") pipeline_dict = pipeline.to_dict()