Skip to content

Commit

Permalink
Fix Snowflake Profile mapping when using AWS default region (#1406)
Browse files Browse the repository at this point in the history
When using a Cosmos Snowflake Profile mapping using a Snowflake account
set in the AWS default region, Cosmos would fail if the default region
was specified in the Airflow connection.

The dbt docs state:
> For AWS accounts in the US West default region, you can use abc123
(without any other segments). For some AWS accounts you will have to
append the region and/or cloud platform. For example, abc123.eu-west-1
or abc123.eu-west-2.aws.

https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#account
 
Although it seems that defining the default region would be optional, a
Cosmos user reported facing 404 and seeing a dbt error message when
attempting to use `SnowflakeUserPasswordProfileMapping` with an Airflow
Snowflake connection that defined the region `us-west-2`.

![snowflake-404](https://github.com/user-attachments/assets/c1884fff-1cad-4c57-b2f3-11a4f44b085b)

We solved the issue by removing the region `us-west-2` from the
connection.

Since this restriction only applies to AWS and this Snowflake region
only exists to AWS, this change seems safe:
![Screenshot 2024-12-18 at 18 45
31](https://github.com/user-attachments/assets/ff2f8a0b-578b-4a62-9fc3-258a43148775)
  • Loading branch information
tatiana authored Dec 19, 2024
1 parent a4f50de commit 22b20f1
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 72 deletions.
28 changes: 28 additions & 0 deletions cosmos/profiles/snowflake/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any

from cosmos.profiles.base import BaseProfileMapping

DEFAULT_AWS_REGION = "us-west-2"


class SnowflakeBaseProfileMapping(BaseProfileMapping):

@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
}
return profile_vars

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
#
if region and region != DEFAULT_AWS_REGION and region not in account:
account = f"{account}.{region}"

return str(account)
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -75,20 +75,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
"private_key": self.get_env_var_format("private_key"),
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
profile_vars = super().profile
profile_vars["private_key"] = self.get_env_var_format("private_key")
profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
22 changes: 4 additions & 18 deletions cosmos/profiles/snowflake/user_encrypted_privatekey_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(BaseProfileMapping):
class SnowflakeEncryptedPrivateKeyFilePemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key path.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -74,20 +74,6 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# private_key_passphrase should always get set as env var
"private_key_passphrase": self.get_env_var_format("private_key_passphrase"),
}

# remove any null values
profile_vars = super().profile
profile_vars["private_key_passphrase"] = self.get_env_var_format("private_key_passphrase")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
23 changes: 5 additions & 18 deletions cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakeUserPasswordProfileMapping(BaseProfileMapping):
class SnowflakeUserPasswordProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/password.
https://docs.getdbt.com/reference/warehouse-setups/snowflake-setup
Expand Down Expand Up @@ -76,20 +76,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

# remove any null values
profile_vars = super().profile
# password should always get set as env var
profile_vars["password"] = self.get_env_var_format("password")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
23 changes: 5 additions & 18 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
from typing import TYPE_CHECKING, Any

from ..base import BaseProfileMapping
from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping

if TYPE_CHECKING:
from airflow.models import Connection


class SnowflakePrivateKeyPemProfileMapping(BaseProfileMapping):
class SnowflakePrivateKeyPemProfileMapping(SnowflakeBaseProfileMapping):
"""
Maps Airflow Snowflake connections to dbt profiles if they use a user/private key.
https://docs.getdbt.com/docs/core/connect-data-platform/snowflake-setup#key-pair-authentication
Expand Down Expand Up @@ -65,20 +65,7 @@ def conn(self) -> Connection:
@property
def profile(self) -> dict[str, Any | None]:
"""Gets profile."""
profile_vars = {
**self.mapped_params,
**self.profile_args,
# private_key should always get set as env var
"private_key": self.get_env_var_format("private_key"),
}

# remove any null values
profile_vars = super().profile
# private_key should always get set as env var
profile_vars["private_key"] = self.get_env_var_format("private_key")
return self.filter_null(profile_vars)

def transform_account(self, account: str) -> str:
"""Transform the account to the format <account>.<region> if it's not already."""
region = self.conn.extra_dejson.get("region")
if region and region not in account:
account = f"{account}.{region}"

return str(account)
19 changes: 19 additions & 0 deletions tests/profiles/snowflake/test_snowflake_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from unittest.mock import patch

from cosmos.profiles.snowflake.base import SnowflakeBaseProfileMapping


@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-west-2"})
@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn")
def test_default_region(mock_conn):
profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn")
response = profile_mapping.transform_account("myaccount")
assert response == "myaccount"


@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn.extra_dejson", {"region": "us-east-1"})
@patch("cosmos.profiles.snowflake.base.SnowflakeBaseProfileMapping.conn")
def test_non_default_region(mock_conn):
profile_mapping = SnowflakeBaseProfileMapping(conn_id="fake-conn")
response = profile_mapping.transform_account("myaccount")
assert response == "myaccount.us-east-1"

0 comments on commit 22b20f1

Please sign in to comment.