Skip to content

Commit

Permalink
adding privatekey auth param
Browse files Browse the repository at this point in the history
  • Loading branch information
iireland-ii committed Nov 14, 2024
1 parent 025a05a commit 25a7992
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
user: str,
account: str,
api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008
private_key_file: Optional[str] = None,
private_key_file_pwd: Optional[str] = None,
database: Optional[str] = None,
db_schema: Optional[str] = None,
warehouse: Optional[str] = None,
Expand All @@ -82,11 +84,16 @@ def __init__(
:param db_schema: Name of the schema to use.
:param warehouse: Name of the warehouse to use.
:param login_timeout: Timeout in seconds for login. By default, 60 seconds.
:param private_key_file: Location of private key
mutually exclusive to password, if key_file is provided this auth method will be used.
:param private_key_file_pwd: Password for private key file
"""

self.user = user
self.account = account
self.api_key = api_key
self.private_key_file = private_key_file
self.private_key_file_pwd = private_key_file_pwd
self.database = database
self.db_schema = db_schema
self.warehouse = warehouse
Expand All @@ -104,6 +111,8 @@ def to_dict(self) -> Dict[str, Any]:
user=self.user,
account=self.account,
api_key=self.api_key.to_dict(),
private_key_file=self.private_key_file,
private_key_file_pwd=self.private_key_file_pwd,
database=self.database,
db_schema=self.db_schema,
warehouse=self.warehouse,
Expand Down Expand Up @@ -275,17 +284,25 @@ def _fetch_data(
if not query:
return df
try:
# Create a new connection with every run
conn = self._snowflake_connector(
connect_params={
# Build up param connection
connect_params={
"user": self.user,
"account": self.account,
"password": self.api_key.resolve_value(),
"private_key_file": self.private_key_file,
"private_key_file_pwd": self.private_key_file_pwd,
"database": self.database,
"schema": self.db_schema,
"warehouse": self.warehouse,
"login_timeout": self.login_timeout,
"login_timeout": self.login_timeout
}

# Check if private key has been provided
if self.private_key_file is None:
connect_params["password"] = self.api_key.resolve_value()

# Create a new connection with every run
conn = self._snowflake_connector(
connect_params=connect_params
)
if conn is None:
return df
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None:
db_schema="test_schema",
warehouse="test_warehouse",
login_timeout=30,
private_key_file=None,
private_key_file_pwd=None
)

data = component.to_dict()
Expand All @@ -494,6 +496,8 @@ def test_to_dict_default(self, monkeypatch: MagicMock) -> None:
"db_schema": "test_schema",
"warehouse": "test_warehouse",
"login_timeout": 30,
"private_key_file": None,
"private_key_file_pwd": None
},
}

Expand Down Expand Up @@ -529,6 +533,8 @@ def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
"db_schema": "SMALL_TOWNS",
"warehouse": "COMPUTE_WH",
"login_timeout": 30,
"private_key_file": None,
"private_key_file_pwd": None
},
}

Expand Down

0 comments on commit 25a7992

Please sign in to comment.