From 25a79926fb525f77ef8cd5797b95f98b28223578 Mon Sep 17 00:00:00 2001 From: Isaac Ireland Date: Thu, 14 Nov 2024 10:01:19 +0000 Subject: [PATCH] adding privatekey auth param --- .../snowflake/snowflake_table_retriever.py | 27 +++++++++++++++---- .../tests/test_snowflake_table_retriever.py | 6 +++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..083060ec3 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -69,6 +69,8 @@ def __init__( user: str, account: str, api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + private_key_file: Optional[str] = None, + private_key_file_pwd: Optional[str] = None, database: Optional[str] = None, db_schema: Optional[str] = None, warehouse: Optional[str] = None, @@ -82,11 +84,16 @@ def __init__( :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + :param private_key_file: Location of private key + mutually exclusive to password, if key_file is provided this auth method will be used. + :param private_key_file_pwd: Password for private key file """ self.user = user self.account = account self.api_key = api_key + self.private_key_file = private_key_file + self.private_key_file_pwd = private_key_file_pwd self.database = database self.db_schema = db_schema self.warehouse = warehouse @@ -104,6 +111,8 @@ def to_dict(self) -> Dict[str, Any]: user=self.user, account=self.account, api_key=self.api_key.to_dict(), + private_key_file=self.private_key_file, + private_key_file_pwd=self.private_key_file_pwd, database=self.database, db_schema=self.db_schema, warehouse=self.warehouse, @@ -275,17 +284,25 @@ def _fetch_data( if not query: return df try: - # Create a new connection with every run - conn = self._snowflake_connector( - connect_params={ + # Build up param connection + connect_params={ "user": self.user, "account": self.account, - "password": self.api_key.resolve_value(), + "private_key_file": self.private_key_file, + "private_key_file_pwd": self.private_key_file_pwd, "database": self.database, "schema": self.db_schema, "warehouse": self.warehouse, - "login_timeout": self.login_timeout, + "login_timeout": self.login_timeout } + + # Check if private key has been provided + if self.private_key_file is None: + connect_params["password"] = self.api_key.resolve_value() + + # Create a new connection with every run + conn = self._snowflake_connector( + connect_params=connect_params ) if conn is None: return df diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index f5b8fee37..4aa990a6c 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -473,6 +473,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: db_schema="test_schema", warehouse="test_warehouse", login_timeout=30, + private_key_file=None, + private_key_file_pwd=None ) data = component.to_dict() @@ -494,6 +496,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: "db_schema": "test_schema", "warehouse": "test_warehouse", "login_timeout": 30, + "private_key_file": None, + "private_key_file_pwd": None }, } @@ -529,6 +533,8 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: "db_schema": "SMALL_TOWNS", "warehouse": "COMPUTE_WH", "login_timeout": 30, + "private_key_file": None, + "private_key_file_pwd": None }, }