diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..f824ac265 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -69,6 +69,8 @@ def __init__( user: str, account: str, api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + private_key_file: Optional[str] = None, + private_key_file_pwd: Optional[str] = None, database: Optional[str] = None, db_schema: Optional[str] = None, warehouse: Optional[str] = None, @@ -82,11 +84,18 @@ def __init__( :param db_schema: Name of the schema to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + :param private_key_file: Location of private key-pair file. + This is mutually exclusive to password, if key_file is provided this auth method will be used. + See Snowflake documentation on Key-Pair authentication for further information: + https://docs.snowflake.com/en/user-guide/key-pair-auth + :param private_key_file_pwd: Password for private key file """ self.user = user self.account = account self.api_key = api_key + self.private_key_file = private_key_file + self.private_key_file_pwd = private_key_file_pwd self.database = database self.db_schema = db_schema self.warehouse = warehouse @@ -104,6 +113,8 @@ def to_dict(self) -> Dict[str, Any]: user=self.user, account=self.account, api_key=self.api_key.to_dict(), + private_key_file=self.private_key_file, + private_key_file_pwd=self.private_key_file_pwd, database=self.database, db_schema=self.db_schema, warehouse=self.warehouse, @@ -275,18 +286,22 @@ def _fetch_data( if not query: return df try: + # Build up param connection + connect_params = { + "user": self.user, + "account": self.account, + "private_key_file": self.private_key_file, + "private_key_file_pwd": self.private_key_file_pwd, + "database": self.database, + "schema": self.db_schema, + "warehouse": self.warehouse, + "login_timeout": self.login_timeout, + } + # Check if private key has been provided + if self.private_key_file is None: + connect_params["password"] = self.api_key.resolve_value() # Create a new connection with every run - conn = self._snowflake_connector( - connect_params={ - "user": self.user, - "account": self.account, - "password": self.api_key.resolve_value(), - "database": self.database, - "schema": self.db_schema, - "warehouse": self.warehouse, - "login_timeout": self.login_timeout, - } - ) + conn = self._snowflake_connector(connect_params=connect_params) if conn is None: return df except (ForbiddenError, ProgrammingError) as e: diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py index f5b8fee37..5991ba9da 100644 --- a/integrations/snowflake/tests/test_snowflake_table_retriever.py +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -473,6 +473,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: db_schema="test_schema", warehouse="test_warehouse", login_timeout=30, + private_key_file=None, + private_key_file_pwd=None, ) data = component.to_dict() @@ -494,6 +496,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None: "db_schema": "test_schema", "warehouse": "test_warehouse", "login_timeout": 30, + "private_key_file": None, + "private_key_file_pwd": None, }, } @@ -529,6 +533,8 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None: "db_schema": "SMALL_TOWNS", "warehouse": "COMPUTE_WH", "login_timeout": 30, + "private_key_file": None, + "private_key_file_pwd": None, }, }