diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md index c84a3b08e..0553a3f4b 100644 --- a/integrations/snowflake/CHANGELOG.md +++ b/integrations/snowflake/CHANGELOG.md @@ -1 +1 @@ -## [integrations/snowflake-v0.0.0] - 2024-09-06 \ No newline at end of file +## [integrations/snowflake-v0.0.1] - 2024-09-06 \ No newline at end of file diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py index 8b47b8f6c..b85a4c677 100644 --- a/integrations/snowflake/example/text2sql_example.py +++ b/integrations/snowflake/example/text2sql_example.py @@ -5,7 +5,7 @@ from haystack.components.generators import OpenAIGenerator from haystack.utils import Secret -from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever load_dotenv() @@ -89,7 +89,7 @@ generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, ) -snowflake = SnowflakeRetriever( +snowflake = SnowflakeTableRetriever( user="", account="", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py index dd409ba06..294d3cce4 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .snowflake_retriever import SnowflakeRetriever +from .snowflake_table_retriever import SnowflakeTableRetriever -__all__ = ["SnowflakeRetriever"] +__all__ = ["SnowflakeTableRetriever"] diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py similarity index 89% rename from integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py rename to integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index 0aa2d5a48..aa6f5ff4d 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -24,7 +24,7 @@ @component -class SnowflakeRetriever: +class SnowflakeTableRetriever: """ Connects to a Snowflake database to execute a SQL query. For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector). @@ -32,7 +32,7 @@ class SnowflakeRetriever: ### Usage example: ```python - executor = SnowflakeRetriever( + executor = SnowflakeTableRetriever( user="", account="", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), @@ -51,12 +51,12 @@ class SnowflakeRetriever: results = executor.run(query=query) - print(results["dataframe"].head(2)) + print(results["dataframe"].head(2)) # Pandas dataframe # Column 1 Column 2 # 0 Value1 Value2 # 1 Value1 Value2 - print(results["table"]) + print(results["table"]) # Markdown # | Column 1 | Column 2 | # |:----------|:----------| # | Value1 | Value2 | @@ -111,7 +111,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeRetriever": + def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": """ Deserializes the component from a dictionary. @@ -140,14 +140,13 @@ def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConne @staticmethod def _extract_table_names(query: str) -> list: """ - Extract table names from a SQL query using regex. + Extract table names from an SQL query using regex. The extracted table names will be checked for privilege. :param query: SQL query to extract table names from. """ - # Regular expressions to match table names in various clauses - suffix = "\\s+([a-zA-Z0-9_.]+)" + suffix = "\\s+([a-zA-Z0-9_.]+)" # Regular expressions to match table names in various clauses patterns = [ "MERGE\\s+INTO", @@ -168,7 +167,7 @@ def _extract_table_names(query: str) -> list: # Find all matches in the query matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE) - # Flatten list of tuples and remove duplication + # Flatten the list of tuples and remove duplication matches = list(set(sum(matches, ()))) # Clean and return unique table names @@ -177,7 +176,7 @@ def _extract_table_names(query: str) -> list: @staticmethod def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: """ - Execute a SQL query and fetch the results. + Execute an SQL query and fetch the results. :param conn: An open connection to Snowflake. :param query: The query to execute. @@ -185,18 +184,21 @@ def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: cur = conn.cursor() try: cur.execute(query) - # set a limit to MAX_SYS_ROWS rows to avoid fetching too many rows - rows = cur.fetchmany(size=MAX_SYS_ROWS) - # Convert data to a dataframe - df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) + rows = cur.fetchmany(size=MAX_SYS_ROWS) # set a limit to avoid fetching too many rows + + df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) # Convert data to a dataframe return df - except ProgrammingError as e: - logger.warning( - "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", - error_msg=e.msg, - sfqid=e.sfqid, - ) - return pd.DataFrame() + except Exception as e: + if isinstance(e, ProgrammingError): + logger.warning( + "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", + error_msg=e.msg, + sfqid=e.sfqid, + ) + else: + logger.warning("An unexpected error occurred: {error_msg}", error_msg=e) + + return pd.DataFrame() @staticmethod def _has_select_privilege(privileges: list, table_name: str) -> bool: @@ -213,7 +215,6 @@ def _has_select_privilege(privileges: list, table_name: str) -> bool: string=privilege[1], flags=re.IGNORECASE, ): - logger.error("User does not have `Select` privilege on the table.") return False return True @@ -304,6 +305,8 @@ def _fetch_data( user=self.user, ): df = self._execute_sql_query(conn=conn, query=query) + else: + logger.error("User does not have `Select` privilege on the table.") except Exception as e: logger.error("An unexpected error has occurred: {error}", error=e) diff --git a/integrations/snowflake/tests/test_snowflake_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py similarity index 70% rename from integrations/snowflake/tests/test_snowflake_retriever.py rename to integrations/snowflake/tests/test_snowflake_table_retriever.py index c3d748086..547f7e1b1 100644 --- a/integrations/snowflake/tests/test_snowflake_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -12,19 +12,20 @@ from haystack import Pipeline from haystack.components.converters import OutputAdapter from haystack.components.generators import OpenAIGenerator +from haystack.components.builders import PromptBuilder from haystack.utils import Secret from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat.chat_completion import Choice from pytest import LogCaptureFixture from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError -from haystack_integrations.components.retrievers.snowflake import SnowflakeRetriever +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever -class TestSnowflakeRetriever: +class TestSnowflakeTableRetriever: @pytest.fixture - def snowflake_retriever(self) -> SnowflakeRetriever: - return SnowflakeRetriever( + def snowflake_table_retriever(self) -> SnowflakeTableRetriever: + return SnowflakeTableRetriever( user="test_user", account="test_account", api_key=Secret.from_token("test-api-key"), @@ -34,12 +35,16 @@ def snowflake_retriever(self) -> SnowflakeRetriever: login_timeout=30, ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_snowflake_connector( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_connect.return_value = mock_conn - conn = snowflake_retriever._snowflake_connector( + conn = snowflake_table_retriever._snowflake_connector( connect_params={ "user": "test_user", "account": "test_account", @@ -62,61 +67,71 @@ def test_snowflake_connector(self, mock_connect: MagicMock, snowflake_retriever: assert conn == mock_conn - def test_query_is_empty(self, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture) -> None: + def test_query_is_empty( + self, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: query = "" - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty assert "Provide a valid SQL query" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_exception( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_connect = mock_connect.return_value mock_connect._fetch_data.side_effect = Exception("Unknown error") query = 4 - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty assert "An unexpected error has occurred" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_forbidden_error_during_connection( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_connect.side_effect = ForbiddenError(msg="Forbidden error", errno=403) - result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table") + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") assert result.empty assert "000403: Forbidden error" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_programing_error_during_connection( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_connect.side_effect = ProgrammingError(msg="Programming error", errno=403) - result = snowflake_retriever._fetch_data(query="SELECT * FROM test_table") + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") assert result.empty assert "000403: Programming error" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_execute_sql_query_programming_error( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_cursor = mock_conn.cursor.return_value mock_cursor.execute.side_effect = ProgrammingError(msg="Simulated programming error", sfqid="ABC-123") - result = snowflake_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") + result = snowflake_table_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") assert result.empty @@ -125,17 +140,21 @@ def test_execute_sql_query_programming_error( "the query in Snowflake UI (ID: ABC-123)" in caplog.text ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_run_connection_error(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_connection_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234) query = "SELECT * FROM test_table" - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["table"] == "" assert result["dataframe"].empty - def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_single_table_name(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: queries = [ "SELECT * FROM table_a", "SELECT name, value FROM (SELECT name, value FROM table_a) AS subquery", @@ -147,31 +166,35 @@ def test_extract_single_table_name(self, snowflake_retriever: SnowflakeRetriever "DROP TABLE table_a", ] for query in queries: - result = snowflake_retriever._extract_table_names(query) + result = snowflake_table_retriever._extract_table_names(query) assert result == ["TABLE_A"] - def test_extract_database_and_schema_from_query(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_database_and_schema_from_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: # when database and schema are next to table name - assert snowflake_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ "DB.SCHEMA.TABLE_A" ] # No database - assert snowflake_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == ["SCHEMA.TABLE_A"] + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == [ + "SCHEMA.TABLE_A" + ] - def test_extract_multiple_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_multiple_table_names(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: queries = [ "MERGE INTO table_a USING table_b ON table_a.id = table_b.id WHEN MATCHED", "SELECT a.name, b.value FROM table_a AS a FULL OUTER JOIN table_b AS b ON a.id = b.id", "SELECT a.name, b.value FROM table_a AS a RIGHT JOIN table_b AS b ON a.id = b.id", ] for query in queries: - result = snowflake_retriever._extract_table_names(query) + result = snowflake_table_retriever._extract_table_names(query) # Due to using set when deduplicating assert result == ["TABLE_A", "TABLE_B"] or ["TABLE_B", "TABLE_A"] - def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: SnowflakeRetriever) -> None: + def test_extract_multiple_db_schema_from_table_names( + self, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: assert ( - snowflake_retriever._extract_table_names( + snowflake_table_retriever._extract_table_names( query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" ) @@ -180,7 +203,7 @@ def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: ) # No database assert ( - snowflake_retriever._extract_table_names( + snowflake_table_retriever._extract_table_names( query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" ) @@ -188,8 +211,12 @@ def test_extract_multiple_db_schema_from_table_names(self, snowflake_retriever: or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] ) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -203,15 +230,17 @@ def test_execute_sql_query(self, mock_connect: MagicMock, snowflake_retriever: S query = "SELECT * FROM test_table" expected = pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}) - result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) assert result.equals(expected) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_is_select_only( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -234,7 +263,7 @@ def test_is_select_only( ] query = "select * from locations" - result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) assert result mock_cursor.fetchall.side_effect = [ @@ -253,13 +282,16 @@ def test_is_select_only( ], ] - result = snowflake_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) assert not result - assert "User does not have `Select` privilege on the table" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_column_after_from( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -273,13 +305,15 @@ def test_column_after_from(self, mock_connect: MagicMock, snowflake_retriever: S query = "SELECT id, extract(year from date_col) as year FROM test_table" expected = pd.DataFrame(data={"id": [1233], "year": [1998]}) - result = snowflake_retriever._execute_sql_query(conn=mock_conn, query=query) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) mock_cursor.execute.assert_called_once_with(query) assert result.equals(expected) - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() mock_col1 = MagicMock() @@ -314,7 +348,7 @@ def test_run(self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetrie "table": "| City | State |\n|:--------|:---------|\n| Chicago | Illinois |", } - result = snowflake_retriever.run(query=query) + result = snowflake_table_retriever.run(query=query) assert result["dataframe"].equals(expected["dataframe"]) assert result["table"] == expected["table"] @@ -327,7 +361,7 @@ def mock_chat_completion(self) -> Generator: with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: completion = ChatCompletion( id="foo", - model="gpt-4o", + model="gpt-4o-mini", object="chat.completion", choices=[ Choice( @@ -344,9 +378,14 @@ def mock_chat_completion(self) -> Generator: mock_chat_completion_create.return_value = completion yield mock_chat_completion_create - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_run_pipeline( - self, mock_connect: MagicMock, mock_chat_completion: MagicMock, snowflake_retriever: SnowflakeRetriever + self, + mock_connect: MagicMock, + mock_chat_completion: MagicMock, + snowflake_table_retriever: SnowflakeTableRetriever, ) -> None: mock_conn = MagicMock() mock_cursor = MagicMock() @@ -385,7 +424,7 @@ def test_run_pipeline( pipeline.add_component("llm", llm) pipeline.add_component("adapter", adapter) - pipeline.add_component("snowflake", snowflake_retriever) + pipeline.add_component("snowflake", snowflake_table_retriever) pipeline.connect(sender="llm.replies", receiver="adapter.replies") pipeline.connect(sender="adapter.output", receiver="snowflake.query") @@ -398,7 +437,8 @@ def test_run_pipeline( def test_from_dict(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") data = { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever" + ".SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -413,7 +453,7 @@ def test_from_dict(self, monkeypatch: MagicMock) -> None: "login_timeout": 3, }, } - component = SnowflakeRetriever.from_dict(data) + component = SnowflakeTableRetriever.from_dict(data) assert component.user == "test_user" assert component.account == "new_account" @@ -425,7 +465,7 @@ def test_from_dict(self, monkeypatch: MagicMock) -> None: def test_to_dict_default(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") - component = SnowflakeRetriever( + component = SnowflakeTableRetriever( user="test_user", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), account="test_account", @@ -438,7 +478,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -457,7 +497,7 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") - component = SnowflakeRetriever( + component = SnowflakeTableRetriever( user="John", api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), account="TGMD-EEREW", @@ -470,7 +510,7 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: data = component.to_dict() assert data == { - "type": "haystack_integrations.components.retrievers.snowflake.snowflake_retriever.SnowflakeRetriever", + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever", "init_parameters": { "api_key": { "env_vars": ["SNOWFLAKE_API_KEY"], @@ -486,8 +526,12 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: }, } - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_has_select_privilege(self, mock_logger: MagicMock, snowflake_retriever: SnowflakeRetriever) -> None: + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_has_select_privilege( + self, mock_logger: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: # Define test cases test_cases = [ # Test case 1: Fully qualified table name in query @@ -523,15 +567,17 @@ def test_has_select_privilege(self, mock_logger: MagicMock, snowflake_retriever: ] for case in test_cases: - result = snowflake_retriever._has_select_privilege( + result = snowflake_table_retriever._has_select_privilege( privileges=case["privileges"], # type: ignore table_name=case["table_name"], # type: ignore ) assert result == case["expected_result"] # type: ignore - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) def test_user_does_not_exist( - self, mock_connect: MagicMock, snowflake_retriever: SnowflakeRetriever, caplog: LogCaptureFixture + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture ) -> None: mock_conn = MagicMock() mock_connect.return_value = mock_conn @@ -539,13 +585,27 @@ def test_user_does_not_exist( mock_cursor = mock_conn.cursor.return_value mock_cursor.fetchall.return_value = [] - result = snowflake_retriever._fetch_data(query="""SELECT * FROM test_table""") + result = snowflake_table_retriever._fetch_data(query="""SELECT * FROM test_table""") assert result.empty assert "User does not exist" in caplog.text - @patch("haystack_integrations.components.retrievers.snowflake.snowflake_retriever.snowflake.connector.connect") - def test_empty_query(self, snowflake_retriever: SnowflakeRetriever) -> None: - result = snowflake_retriever._fetch_data(query="") + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_empty_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + result = snowflake_table_retriever._fetch_data(query="") assert result.empty + + def test_serialization_deserialization_pipeline(self) -> None: + + pipeline = Pipeline() + pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account")) + pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}")) + pipeline.connect("snow.table", "prompt_builder.table") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline