Increase test coverage and fix implementation

tatiana committed Apr 18, 2024
1 parent 9325fac commit 5c010cc
Showing 9 changed files with 97 additions and 245 deletions.
10 changes: 6 additions & 4 deletions cosmos/cache.py
@@ -11,6 +11,11 @@
 from cosmos.dbt.project import get_partial_parse_path


+# It was considered to create a cache identifier based on the dbt project path, as opposed
+# to where it is used in Airflow. However, we could have concurrency issues if the same
+# dbt cached directory was being used by different dbt task groups or DAGs within the same
+# node. For this reason, as a starting point, the cache is identified by where it is used.
+# This can be reviewed in the future.
 def create_cache_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
     """
     Given a DAG name and a (optional) task_group_name, create the identifier for caching.
@@ -22,11 +27,8 @@ def create_cache_identifier(dag: DAG, task_group: TaskGroup | None) -> str:
     if task_group:
         if task_group.dag_id is not None:
             cache_identifiers_list = [task_group.dag_id]
-        if task_group.upstream_group_ids is not None:
-            group_ids: list[str] = [tg for tg in task_group.upstream_group_ids or [] if tg is not None]
-            cache_identifiers_list.extend(group_ids)
         if task_group.group_id is not None:
-            cache_identifiers_list.extend(task_group.group_id)
+            cache_identifiers_list.extend([task_group.group_id.replace(".", "_")])
         cache_identifier = "_".join(cache_identifiers_list)
     else:
         cache_identifier = dag.dag_id
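To make the new naming concrete, here is a minimal sketch of the identifiers this function now produces (the DAG and task group names are made up for illustration). Note that a nested group id such as outer.inner is flattened with underscores, whereas the replaced extend call used to spread the string character by character:

# Hedged sketch of the behaviour introduced above; DAG and group names are hypothetical.
import pendulum
from airflow import DAG
from airflow.utils.task_group import TaskGroup

from cosmos.cache import create_cache_identifier

with DAG(dag_id="my_dag", start_date=pendulum.datetime(2024, 1, 1)) as dag:
    with TaskGroup(group_id="outer"):
        with TaskGroup(group_id="inner") as inner:
            pass  # tasks would go here

print(create_cache_identifier(dag, task_group=None))   # my_dag
print(create_cache_identifier(dag, task_group=inner))  # my_dag_outer_inner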
3 changes: 2 additions & 1 deletion cosmos/config.py
@@ -288,7 +288,8 @@ def ensure_profile(
         with tempfile.TemporaryDirectory() as temp_dir:
             temp_file = Path(temp_dir) / DEFAULT_PROFILES_FILE_NAME
             logger.info(
-                "Creating temporary profiles.yml at %s with the following contents:\n%s",
+                "Creating temporary profiles.yml with use_mock_values=%s at %s with the following contents:\n%s",
+                use_mock_values,
                 temp_file,
                 profile_contents,
             )
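ensure_profile is the context manager that renders profiles.yml from the configured profile mapping, and use_mock_values controls whether placeholder credentials are rendered (as happens during the dbt ls parsing step) instead of real connection values, so surfacing it in this log line tells readers which variant they are looking at. A rough usage sketch, with the connection id and profile arguments as assumptions:

from cosmos import ProfileConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

# Illustrative ProfileConfig; conn_id and schema are assumptions for this sketch.
profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=PostgresUserPasswordProfileMapping(
        conn_id="my_postgres_conn",
        profile_args={"schema": "public"},
    ),
)

# ensure_profile yields the rendered profiles.yml path and any environment
# variables holding secrets; use_mock_values=True substitutes placeholder
# credentials, which is what the log message above now records.
with profile_config.ensure_profile(use_mock_values=True) as (profile_path, env_vars):
    print("profiles.yml rendered at:", profile_path)
    print("secret env var keys:", list(env_vars.keys()))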
12 changes: 0 additions & 12 deletions cosmos/converter.py
@@ -215,8 +215,6 @@ def __init__(

         validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args)

-        # If we are using the old interface, we should migrate it to the new interface
-        # This is safe to do now since we have validated which config interface we're using
         if project_config.dbt_project_path:
             execution_config, render_config = migrate_to_new_interface(execution_config, project_config, render_config)

@@ -227,16 +225,6 @@

         cache_dir = cache.obtain_cache_dir_path(cache_identifier=cache.create_cache_identifier(dag, task_group))

-        # Previously, we were creating a cosmos.dbt.project.DbtProject
-        # DbtProject has now been replaced with ProjectConfig directly
-        # since the interface of the two classes were effectively the same
-        # Under this previous implementation, we were passing:
-        # - name, root dir, models dir, snapshots dir and manifest path
-        # Internally in the dbtProject class, we were defaulting the profile_path
-        # To be root dir/profiles.yml
-        # To keep this logic working, if converter is given no ProfileConfig,
-        # we can create a default retaining this value to preserve this functionality.
-        # We may want to consider defaulting this value in our actual ProjectConfig class?
         self.dbt_graph = DbtGraph(
             project=project_config,
             render_config=render_config,
4 changes: 2 additions & 2 deletions cosmos/dbt/graph.py
@@ -74,7 +74,7 @@ def name(self) -> str:
 def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
     """Run a command in a subprocess, returning the stdout."""
     logger.info("Running command: `%s`", " ".join(command))
-    logger.info("Environment variable keys: %s", env_vars.keys())
+    logger.debug("Environment variable keys: %s", env_vars.keys())
     process = Popen(
         command,
         stdout=PIPE,
@@ -215,7 +215,7 @@ def run_dbt_ls(

         stdout = run_command(ls_command, tmp_dir, env_vars)

-        logger.debug("dbt ls output: %s", stdout)
+        logger.info("dbt ls output: %s", stdout)
         log_filepath = self.log_dir / DBT_LOG_FILENAME
         logger.debug("dbt logs available in: %s", log_filepath)
         if log_filepath.exists():
2 changes: 1 addition & 1 deletion cosmos/operators/local.py
@@ -325,7 +325,7 @@ def run_command(

         full_cmd = cmd + flags

-        logger.info("Using environment variables keys: %s", env.keys())
+        logger.debug("Using environment variables keys: %s", env.keys())

         result = self.invoke_dbt(
             command=full_cmd,
10 changes: 8 additions & 2 deletions dev/dags/basic_cosmos_task_group.py
@@ -25,7 +25,10 @@
     ),
 )

-shared_execution_config = ExecutionConfig(invocation_mode=InvocationMode.DBT_RUNNER)
+shared_execution_config = ExecutionConfig(
+    invocation_mode=InvocationMode.SUBPROCESS,
+    # invocation_mode=InvocationMode.DBT_RUNNER
+)


 @dag(
@@ -56,7 +59,10 @@ def basic_cosmos_task_group() -> None:
         project_config=ProjectConfig(
             (DBT_ROOT_PATH / "jaffle_shop").as_posix(),
         ),
-        render_config=RenderConfig(select=["path:seeds/raw_orders.csv"]),
+        render_config=RenderConfig(
+            select=["path:seeds/raw_orders.csv"],
+            enable_mock_profile=False,  # This is necessary to benefit from partial parsing when using ProfileMapping
+        ),
         execution_config=shared_execution_config,
         operator_args={"install_deps": True},
         profile_config=profile_config,
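A note on the enable_mock_profile=False comment above: when Cosmos renders the project with dbt ls it normally fills the profile with mock values, so the resulting profiles.yml differs from the one used at execution time and dbt discards its partial parse artifact; keeping the real profile during parsing makes the file stable across runs. A minimal sketch of a DbtDag configured this way, where the connection id, schema, and project path are illustrative assumptions:

from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=PostgresUserPasswordProfileMapping(
        conn_id="example_postgres_conn",  # assumed Airflow connection id
        profile_args={"schema": "public"},
    ),
)

# With enable_mock_profile=False, dbt ls sees the same profiles.yml as execution,
# so dbt's partial_parse.msgpack can be reused between parses.
partial_parse_dag = DbtDag(
    dag_id="jaffle_shop_partial_parse",  # dag_id and project path are illustrative
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),
    profile_config=profile_config,
    render_config=RenderConfig(enable_mock_profile=False),
)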
18 changes: 16 additions & 2 deletions tests/dbt/test_graph.py
@@ -482,7 +482,9 @@ def test_load_via_dbt_ls_without_dbt_deps(postgres_profile_config):


 @pytest.mark.integration
-def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_project_dir, postgres_profile_config):
+def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(
+    tmp_dbt_project_dir, postgres_profile_config, caplog, tmp_path
+):
     local_flags = [
         "--project-dir",
         tmp_dbt_project_dir / DBT_PROJECT_NAME,
@@ -506,17 +508,29 @@ def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(
     stdout, stderr = process.communicate()

     project_config = ProjectConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME)
-    render_config = RenderConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, dbt_deps=False)
+    render_config = RenderConfig(
+        dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, dbt_deps=False, enable_mock_profile=False
+    )
     execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME)
     dbt_graph = DbtGraph(
         project=project_config,
         render_config=render_config,
         execution_config=execution_config,
         profile_config=postgres_profile_config,
+        cache_dir=tmp_path,
     )

+    from cosmos.constants import DBT_TARGET_DIR_NAME
+
+    (tmp_path / DBT_TARGET_DIR_NAME).mkdir(parents=True, exist_ok=True)
+
     dbt_graph.load_via_dbt_ls()  # does not raise exception
+
+    assert "Unable to do partial parsing" in caplog.text
+    # TODO: split the caching test into a separate test, and make the following assertion work
+    # dbt_graph.load_via_dbt_ls()  # does not raise exception
+    # assert not "Unable to do partial parsing" in caplog.text


 @pytest.mark.integration
 @patch("cosmos.dbt.graph.Popen")
221 changes: 0 additions & 221 deletions tests/plugin/test_plugin.py

This file was deleted.

(Diff for the remaining changed file did not load.)
