Skip to content

Commit

Permalink
Support ProjectConfig.dbt_project_path = None & different paths for…
Browse files Browse the repository at this point in the history
… Rendering and Execution (#634)

This MR finishes the work that was started in #605 to add full support
for ProjectConfig.dbt_project_path = None, and implements #568.

Within this PR, several things have been updated:
1 - Added project_path fields to RenderConfig and ExecutionConfig
2 - Simplified the consumption of RenderConfig in the dbtGraph class
3 - added option to configure different dbt executables for Rendering vs
Execution.

Closes: #568
  • Loading branch information
MrBones757 authored and tatiana committed Jul 29, 2024
1 parent cd47ddc commit ad54b59
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 81 deletions.
2 changes: 1 addition & 1 deletion cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ class ExecutionConfig:

dbt_project_path: InitVar[str | Path | None] = None
virtualenv_dir: str | Path | None = None

project_path: Path | None = field(init=False)

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
Expand All @@ -385,4 +386,3 @@ def __post_init__(self, dbt_project_path: str | Path | None) -> None:
)
self.invocation_mode = InvocationMode.SUBPROCESS
self.project_path = Path(dbt_project_path) if dbt_project_path else None

8 changes: 4 additions & 4 deletions cosmos/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def __init__(

validate_adapted_user_config(execution_config, project_config, render_config)

env_vars = project_config.env_vars or operator_args.get("env")
dbt_vars = project_config.dbt_vars or operator_args.get("vars")
env_vars = copy.deepcopy(project_config.env_vars or operator_args.get("env"))
dbt_vars = copy.deepcopy(project_config.dbt_vars or operator_args.get("vars"))

if execution_mode != ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
if execution_config.execution_mode != ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
logger.warning(
"`ExecutionConfig.virtualenv_dir` is only supported when \
ExecutionConfig.execution_mode is set to ExecutionMode.VIRTUALENV."
Expand Down Expand Up @@ -288,7 +288,7 @@ def __init__(
task_args,
execution_mode=execution_config.execution_mode,
)
if (execution_mode == ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None):
if execution_config.execution_mode == ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
task_args["virtualenv_dir"] = execution_config.virtualenv_dir

build_airflow_graph(
Expand Down
27 changes: 26 additions & 1 deletion cosmos/dbt/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ExecutionMode,
LoadMode,
)
from cosmos.dbt.executable import get_system_dbt
from cosmos.dbt.parser.project import LegacyDbtProject
from cosmos.dbt.project import create_symlinks, environ, get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.dbt.selector import select_nodes
Expand Down Expand Up @@ -99,6 +100,14 @@ def context_dict(self) -> dict[str, Any]:
}


def create_symlinks(project_path: Path, tmp_dir: Path) -> None:
"""Helper function to create symlinks to the dbt project files."""
ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml")
for child_name in os.listdir(project_path):
if child_name not in ignore_paths:
os.symlink(project_path / child_name, tmp_dir / child_name)


def run_command(command: list[str], tmp_dir: Path, env_vars: dict[str, str]) -> str:
"""Run a command in a subprocess, returning the stdout."""
logger.info("Running command: `%s`", " ".join(command))
Expand Down Expand Up @@ -154,6 +163,15 @@ class DbtGraph:
Supports different ways of loading the `dbt` project into this representation.
Different loading methods can result in different `nodes` and `filtered_nodes`.
Example of how to use:
dbt_graph = DbtGraph(
project=ProjectConfig(dbt_project_path=DBT_PROJECT_PATH),
render_config=RenderConfig(exclude=["*orders*"], select=[]),
dbt_cmd="/usr/local/bin/dbt"
)
dbt_graph.load(method=LoadMode.DBT_LS, execution_mode=ExecutionMode.LOCAL)
"""

nodes: dict[str, DbtNode] = dict()
Expand All @@ -170,6 +188,8 @@ def __init__(
cache_identifier: str = "",
dbt_vars: dict[str, str] | None = None,
airflow_metadata: dict[str, str] | None = None,
dbt_cmd: str = get_system_dbt(),
operator_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
Expand All @@ -182,6 +202,8 @@ def __init__(
else:
self.dbt_ls_cache_key = ""
self.dbt_vars = dbt_vars or {}
self.operator_args = operator_args or {}
self.dbt_cmd = dbt_cmd

@cached_property
def env_vars(self) -> dict[str, str]:
Expand Down Expand Up @@ -347,6 +369,7 @@ def load(
def run_dbt_ls(
self, dbt_cmd: str, project_path: Path, tmp_dir: Path, env_vars: dict[str, str]
) -> dict[str, DbtNode]:

"""Runs dbt ls command and returns the parsed nodes."""
ls_command = [dbt_cmd, "ls", "--output", "json"]

Expand Down Expand Up @@ -446,6 +469,9 @@ def load_via_dbt_ls_without_cache(self) -> None:
if not self.profile_config:
raise CosmosLoadDbtException("Unable to load project via dbt ls without a profile config.")

if not self.profile_config:
raise CosmosLoadDbtException("Unable to load project via dbt ls without a profile config.")

with tempfile.TemporaryDirectory() as tmpdir:
logger.debug(f"Content of the dbt project dir {project_path}: `{os.listdir(project_path)}`")
tmpdir_path = Path(tmpdir)
Expand Down Expand Up @@ -490,7 +516,6 @@ def load_via_dbt_ls_without_cache(self) -> None:
self.run_dbt_deps(dbt_cmd, tmpdir_path, env)

nodes = self.run_dbt_ls(dbt_cmd, self.project_path, tmpdir_path, env)

self.nodes = nodes
self.filtered_nodes = nodes

Expand Down
13 changes: 6 additions & 7 deletions tests/dbt/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from cosmos import settings
from cosmos.config import CosmosConfigException, ExecutionConfig, ProfileConfig, ProjectConfig, RenderConfig
from cosmos.constants import DBT_TARGET_DIR_NAME, DbtResourceType, ExecutionMode

from cosmos.dbt.graph import (
CosmosLoadDbtException,
DbtGraph,
Expand Down Expand Up @@ -498,8 +499,9 @@ def test_load_via_dbt_ls_without_exclude(project_name, postgres_profile_config):
def test_load_via_custom_without_project_path():
project_config = ProjectConfig(manifest_path=SAMPLE_MANIFEST, project_name="test")
execution_config = ExecutionConfig()
render_config = RenderConfig(dbt_executable_path="/inexistent/dbt")
render_config = RenderConfig()
dbt_graph = DbtGraph(
dbt_cmd="/inexistent/dbt",
project=project_config,
execution_config=execution_config,
render_config=render_config,
Expand All @@ -515,10 +517,9 @@ def test_load_via_custom_without_project_path():
def test_load_via_dbt_ls_without_profile(mock_validate_dbt_command):
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(
dbt_executable_path="existing-dbt-cmd", dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME
)
render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
dbt_graph = DbtGraph(
dbt_cmd="/inexistent/dbt",
project=project_config,
execution_config=execution_config,
render_config=render_config,
Expand All @@ -534,9 +535,7 @@ def test_load_via_dbt_ls_without_profile(mock_validate_dbt_command):
def test_load_via_dbt_ls_with_invalid_dbt_path(mock_which):
project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
render_config = RenderConfig(
dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, dbt_executable_path="/inexistent/dbt"
)
render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME)
with patch("pathlib.Path.exists", return_value=True):
dbt_graph = DbtGraph(
project=project_config,
Expand Down
68 changes: 0 additions & 68 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,74 +270,6 @@ def test_converter_fails_execution_config_no_project_dir(mock_load_dbt_graph, ex
)


def test_converter_fails_render_config_invalid_dbt_path_with_dbt_ls():
"""
Validate that a dbt project fails to be rendered to Airflow with DBT_LS if
the dbt command is invalid.
"""
project_config = ProjectConfig(dbt_project_path=SAMPLE_DBT_PROJECT.as_posix(), project_name="sample")
execution_config = ExecutionConfig(
execution_mode=ExecutionMode.LOCAL,
dbt_executable_path="invalid-execution-dbt",
)
render_config = RenderConfig(
emit_datasets=True,
dbt_executable_path="invalid-render-dbt",
)
profile_config = ProfileConfig(
profile_name="my_profile_name",
target_name="my_target_name",
profiles_yml_filepath=SAMPLE_PROFILE_YML,
)
with pytest.raises(CosmosConfigException) as err_info:
with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag:
DbtToAirflowConverter(
dag=dag,
nodes=nodes,
project_config=project_config,
profile_config=profile_config,
execution_config=execution_config,
render_config=render_config,
)
assert (
err_info.value.args[0]
== "Unable to find the dbt executable, attempted: <invalid-render-dbt> and <invalid-execution-dbt>."
)


def test_converter_fails_render_config_invalid_dbt_path_with_manifest():
"""
Validate that a dbt project succeeds to be rendered to Airflow with DBT_MANIFEST even when
the dbt command is invalid.
"""
project_config = ProjectConfig(manifest_path=SAMPLE_DBT_MANIFEST.as_posix(), project_name="sample")

execution_config = ExecutionConfig(
execution_mode=ExecutionMode.LOCAL,
dbt_executable_path="invalid-execution-dbt",
dbt_project_path=SAMPLE_DBT_PROJECT.as_posix(),
)
render_config = RenderConfig(
emit_datasets=True,
dbt_executable_path="invalid-render-dbt",
)
profile_config = ProfileConfig(
profile_name="my_profile_name",
target_name="my_target_name",
profiles_yml_filepath=SAMPLE_PROFILE_YML,
)
with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag:
converter = DbtToAirflowConverter(
dag=dag,
nodes=nodes,
project_config=project_config,
profile_config=profile_config,
execution_config=execution_config,
render_config=render_config,
)
assert converter


@pytest.mark.parametrize(
"execution_mode,operator_args",
[
Expand Down

0 comments on commit ad54b59

Please sign in to comment.