From 8a435d92db381b9288bab5ce254b26304fd22565 Mon Sep 17 00:00:00 2001 From: Tobias Wochinger Date: Fri, 13 Dec 2024 17:06:06 +0100 Subject: [PATCH 1/2] chore: add application name (#1245) * chore: add application name * fix parentheses to dataframe object --------- Co-authored-by: Mo Sriha <22803208+medsriha@users.noreply.github.com> --- .../snowflake/snowflake_table_retriever.py | 7 ++- .../tests/test_snowflake_table_retriever.py | 62 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..3cbad3c9d 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -73,6 +73,7 @@ def __init__( db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = None, + application_name: Optional[str] = None, ) -> None: """ :param user: User's login. @@ -82,6 +83,7 @@ def __init__( :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + :param application_name: Name of the application to use when connecting to Snowflake. """ self.user = user @@ -91,6 +93,7 @@ def __init__( self.db_schema = db_schema self.warehouse = warehouse self.login_timeout = login_timeout or 60 + self.application_name = application_name def to_dict(self) -> Dict[str, Any]: """ @@ -108,6 +111,7 @@ def to_dict(self) -> Dict[str, Any]: db_schema=self.db_schema, warehouse=self.warehouse, login_timeout=self.login_timeout, + application_name=self.application_name, ) @classmethod @@ -285,6 +289,7 @@ def _fetch_data( "schema": self.db_schema, "warehouse": self.warehouse, "login_timeout": self.login_timeout, + **({"application": self.application_name} if self.application_name else {}), } ) if conn is None: @@ -325,7 +330,7 @@ def run(self, query: str) -> Dict[str, Any]: if not query: logger.error("Provide a valid SQL query.") return { - "dataframe": pd.DataFrame, + "dataframe": pd.DataFrame(), "table": "", } else: diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index f5b8fee37..3e6e7d547 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -352,6 +352,64 @@ def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: Snowflake assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_with_application_name( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + snowflake_table_retriever.application_name = "test_application" + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "locations", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.description = [mock_col1, mock_col2] + + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM locations" + + snowflake_table_retriever.run(query=query) + + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + password="test-api-key", + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + application="test_application", + ) @pytest.fixture def mock_chat_completion(self) -> Generator: @@ -494,6 +552,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: "db_schema": "test_schema", "warehouse": "test_warehouse", "login_timeout": 30, + "application_name": None, }, } @@ -508,6 +567,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: db_schema="SMALL_TOWNS", warehouse="COMPUTE_WH", login_timeout=30, + application_name="test_application", ) data = component.to_dict() @@ -529,6 +589,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: "db_schema": "SMALL_TOWNS", "warehouse": "COMPUTE_WH", "login_timeout": 30, + "application_name": "test_application", }, } @@ -605,7 +666,6 @@ def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) - assert result.empty def test_serialization_deserialization_pipeline(self) -> None: - pipeline = Pipeline() pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}")) From 0e207915ef018ec5d9e3a00617b37806060a8299 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 13 Dec 2024 16:09:01 +0000 Subject: [PATCH 2/2] Update the changelog --- integrations/snowflake/CHANGELOG.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md index 757bfb3fe..356bbcace 100644 --- a/integrations/snowflake/CHANGELOG.md +++ b/integrations/snowflake/CHANGELOG.md @@ -1,13 +1,29 @@ # Changelog +## [integrations/snowflake-v0.0.3] - 2024-12-13 + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + +### ๐Ÿงน Chores + +- Update ruff linting scripts and settings (#1105) +- Add application name (#1245) + + ## [integrations/snowflake-v0.0.2] - 2024-09-25 ### ๐Ÿš€ Features - Add Snowflake integration (#1064) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Adding github workflow for Snowflake (#1097) +### ๐ŸŒ€ Miscellaneous + +- Docs: upd snowflake pydoc (#1102) +