Skip to content

Commit

Permalink
added vertica profile with test
Browse files Browse the repository at this point in the history
Signed-off-by: Perttu Salonen <[email protected]>
  • Loading branch information
perttus committed Sep 15, 2023
1 parent 27d1945 commit 9af20f2
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 2 deletions.
3 changes: 3 additions & 0 deletions cosmos/profiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .trino.certificate import TrinoCertificateProfileMapping
from .trino.jwt import TrinoJWTProfileMapping
from .trino.ldap import TrinoLDAPProfileMapping
from .vertica.user_pass import VerticaUserPasswordProfileMapping

profile_mappings: list[Type[BaseProfileMapping]] = [
GoogleCloudServiceAccountFileProfileMapping,
Expand All @@ -34,6 +35,7 @@
TrinoLDAPProfileMapping,
TrinoCertificateProfileMapping,
TrinoJWTProfileMapping,
VerticaUserPasswordProfileMapping,
]


Expand Down Expand Up @@ -70,4 +72,5 @@ def get_automatic_profile_mapping(
"TrinoLDAPProfileMapping",
"TrinoCertificateProfileMapping",
"TrinoJWTProfileMapping",
"VerticaUserPasswordProfileMapping",
]
5 changes: 5 additions & 0 deletions cosmos/profiles/vertica/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"Vertica Airflow connection -> dbt profile mappings"

from .user_pass import VerticaUserPasswordProfileMapping

__all__ = ["VerticaUserPasswordProfileMapping"]
76 changes: 76 additions & 0 deletions cosmos/profiles/vertica/user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"Maps Airflow Vertica connections using user + password authentication to dbt profiles."
from __future__ import annotations

from typing import Any

from ..base import BaseProfileMapping


class VerticaUserPasswordProfileMapping(BaseProfileMapping):
"""
Maps Airflow Vertica connections using user + password authentication to dbt profiles.
https://docs.getdbt.com/reference/warehouse-setups/vertica-setup
https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html
"""

airflow_connection_type: str = "vertica"
dbt_profile_type: str = "vertica"

required_fields = [
"host",
"user",
"password",
"database",
"schema",
]
secret_fields = [
"password",
]
airflow_param_mapping = {
"host": "host",
"user": "login",
"password": "password",
"port": "port",
"schema": "schema",
"database": "extra.database",
"autocommit": "extra.autocommit",
"backup_server_node": "extra.backup_server_node",
"binary_transfer": "extra.binary_transfer",
"connection_load_balance": "extra.connection_load_balance",
"connection_timeout": "extra.connection_timeout",
"disable_copy_local": "extra.disable_copy_local",
"kerberos_host_name": "extra.kerberos_host_name",
"kerberos_service_name": "extra.kerberos_service_name",
"log_level": "extra.log_level",
"log_path": "extra.log_path",
"oauth_access_token": "extra.oauth_access_token",
"request_complex_types": "extra.request_complex_types",
"session_label": "extra.session_label",
"ssl": "extra.ssl",
"unicode_error": "extra.unicode_error",
"use_prepared_statements": "extra.use_prepared_statements",
"workload": "extra.workload",
}

@property
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
"port": 5433,
**self.mapped_params,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(profile)

@property
def mock_profile(self) -> dict[str, Any | None]:
"Gets mock profile. Defaults port to 5433."
parent_mock = super().mock_profile

return {
"port": 5433,
**parent_mock,
}
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dbt-all = [
"dbt-redshift<=1.5.4",
"dbt-snowflake<=1.5.4",
"dbt-spark<=1.5.4",
"dbt-vertica<=1.5.4",
]
dbt-bigquery = [
"dbt-bigquery<=1.5.4",
Expand All @@ -75,6 +76,9 @@ dbt-snowflake = [
dbt-spark = [
"dbt-spark<=1.5.4",
]
dbt-vertica = [
"dbt-vertica<=1.5.4",
]
openlineage = [
"openlineage-integration-common",
"openlineage-airflow",
Expand Down Expand Up @@ -160,10 +164,10 @@ test = 'pytest -vv --durations=0 . -m "not integration" --ignore=tests/test_exam
test-cov = """pytest -vv --cov=cosmos --cov-report=term-missing --cov-report=xml --durations=0 -m "not integration" --ignore=tests/test_example_dags.py --ignore=tests/test_example_dags_no_connections.py"""
# we install using the following workaround to overcome installation conflicts, such as:
# apache-airflow 2.3.0 and dbt-core [0.13.0 - 1.5.2] and jinja2>=3.0.0 because these package versions have conflicting dependencies
test-integration-setup = """pip uninstall dbt-postgres dbt-databricks; \
test-integration-setup = """pip uninstall dbt-postgres dbt-databricks dbt-vertica; \
rm -rf airflow.*; \
airflow db init; \
pip install 'dbt-core==1.5.4' 'dbt-databricks<=1.5.4' 'dbt-postgres<=1.5.4' 'openlineage-airflow'"""
pip install 'dbt-core==1.5.4' 'dbt-databricks<=1.5.4' 'dbt-postgres<=1.5.4' 'dbt-vertica<=1.5.4' 'openlineage-airflow'"""
test-integration = """rm -rf dbt/jaffle_shop/dbt_packages;
pytest -vv \
--cov=cosmos \
Expand Down
183 changes: 183 additions & 0 deletions tests/profiles/vertica/test_vertica_user_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"Tests for the vertica profile."

from unittest.mock import patch

import pytest
from airflow.models.connection import Connection

from cosmos.profiles import get_automatic_profile_mapping
from cosmos.profiles.vertica.user_pass import (
VerticaUserPasswordProfileMapping,
)


@pytest.fixture()
def mock_vertica_conn(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_vertica_connection",
conn_type="vertica",
host="my_host",
login="my_user",
password="my_password",
port=5432,
schema="my_schema",
extra='{"database": "my_database"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


@pytest.fixture()
def mock_vertica_conn_custom_port(): # type: ignore
"""
Sets the connection as an environment variable.
"""
conn = Connection(
conn_id="my_vertica_connection",
conn_type="vertica",
host="my_host",
login="my_user",
password="my_password",
port=7472,
schema="my_schema",
extra='{"database": "my_database"}',
)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
yield conn


def test_connection_claiming() -> None:
"""
Tests that the vertica profile mapping claims the correct connection type.
"""
# should only claim when:
# - conn_type == vertica
# and the following exist:
# - host
# - user
# - password
# - port
# - database or database
# - schema
potential_values = {
"conn_type": "vertica",
"host": "my_host",
"login": "my_user",
"password": "my_password",
"schema": "my_schema",
"extra": '{"database": "my_database"}',
}

# if we're missing any of the values, it shouldn't claim
for key in potential_values:
values = potential_values.copy()
del values[key]
conn = Connection(**values) # type: ignore

print("testing with", values)

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# also test when there's no database
conn = Connection(**potential_values) # type: ignore
conn.extra = ''
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert not profile_mapping.can_claim_connection()

# if we have them all, it should claim
conn = Connection(**potential_values) # type: ignore
with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = VerticaUserPasswordProfileMapping(conn)
assert profile_mapping.can_claim_connection()


def test_profile_mapping_selected(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the correct profile mapping is selected.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
{"schema": "my_schema"},
)
assert isinstance(profile_mapping, VerticaUserPasswordProfileMapping)


def test_profile_mapping_keeps_custom_port(mock_vertica_conn_custom_port: Connection) -> None:
profile = VerticaUserPasswordProfileMapping(mock_vertica_conn_custom_port.conn_id, {"schema": "my_schema"})
assert profile.profile["port"] == 7472


def test_profile_args(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the profile values get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
}

assert profile_mapping.profile == {
"type": mock_vertica_conn.conn_type,
"host": mock_vertica_conn.host,
"user": mock_vertica_conn.login,
"password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}",
"port": mock_vertica_conn.port,
"schema": "my_schema",
"database": mock_vertica_conn.extra_dejson.get("database"),
}


def test_profile_args_overrides(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that you can override the profile values.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema", "database": "my_db_override"},
)
assert profile_mapping.profile_args == {
"schema": "my_schema",
"database": "my_db_override",
}

assert profile_mapping.profile == {
"type": mock_vertica_conn.conn_type,
"host": mock_vertica_conn.host,
"user": mock_vertica_conn.login,
"password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}",
"port": mock_vertica_conn.port,
"database": "my_db_override",
"schema": "my_schema",
}


def test_profile_env_vars(
mock_vertica_conn: Connection,
) -> None:
"""
Tests that the environment variables get set correctly.
"""
profile_mapping = get_automatic_profile_mapping(
mock_vertica_conn.conn_id,
profile_args={"schema": "my_schema"},
)
assert profile_mapping.env_vars == {
"COSMOS_CONN_VERTICA_PASSWORD": mock_vertica_conn.password,
}

0 comments on commit 9af20f2

Please sign in to comment.