diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 3e7961ed1..3a140a235 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -6,7 +6,14 @@ from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup -from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, TESTABLE_DBT_RESOURCES, DEFAULT_DBT_RESOURCES +from cosmos.constants import ( + DbtResourceType, + TestBehavior, + TestIndirectSelection, + ExecutionMode, + TESTABLE_DBT_RESOURCES, + DEFAULT_DBT_RESOURCES, +) from cosmos.core.airflow import get_airflow_task as create_airflow_task from cosmos.core.graph.entities import Task as TaskMetadata from cosmos.dbt.graph import DbtNode @@ -54,6 +61,7 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st def create_test_task_metadata( test_task_name: str, execution_mode: ExecutionMode, + test_indirect_selection: TestIndirectSelection, task_args: dict[str, Any], on_warning_callback: Callable[..., Any] | None = None, node: DbtNode | None = None, @@ -71,6 +79,8 @@ def create_test_task_metadata( """ task_args = dict(task_args) task_args["on_warning_callback"] = on_warning_callback + if test_indirect_selection != TestIndirectSelection.EAGER: + task_args["indirect_selection"] = test_indirect_selection.value if node is not None: if node.resource_type == DbtResourceType.MODEL: task_args["models"] = node.name @@ -144,6 +154,7 @@ def generate_task_or_group( execution_mode: ExecutionMode, task_args: dict[str, Any], test_behavior: TestBehavior, + test_indirect_selection: TestIndirectSelection, on_warning_callback: Callable[..., Any] | None, **kwargs: Any, ) -> BaseOperator | TaskGroup | None: @@ -169,6 +180,7 @@ def generate_task_or_group( test_meta = create_test_task_metadata( "test", execution_mode, + test_indirect_selection, task_args=task_args, node=node, on_warning_callback=on_warning_callback, @@ -187,6 +199,7 @@ def build_airflow_graph( execution_mode: ExecutionMode, # Cosmos-specific - decide what which class to use task_args: dict[str, Any], # Cosmos/DBT - used to instantiate tasks test_behavior: TestBehavior, # Cosmos-specific: how to inject tests to Airflow DAG + test_indirect_selection: TestIndirectSelection, # Cosmos/DBT - used to set test indirect selection mode dbt_project_name: str, # DBT / Cosmos - used to name test task if mode is after_all, task_group: TaskGroup | None = None, on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command @@ -235,6 +248,7 @@ def build_airflow_graph( execution_mode=execution_mode, task_args=task_args, test_behavior=test_behavior, + test_indirect_selection=test_indirect_selection, on_warning_callback=on_warning_callback, node=node, ) @@ -246,7 +260,11 @@ def build_airflow_graph( # The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks) if test_behavior == TestBehavior.AFTER_ALL: test_meta = create_test_task_metadata( - f"{dbt_project_name}_test", execution_mode, task_args=task_args, on_warning_callback=on_warning_callback + f"{dbt_project_name}_test", + execution_mode, + test_indirect_selection, + task_args=task_args, + on_warning_callback=on_warning_callback, ) test_task = create_airflow_task(test_meta, dag, task_group=task_group) leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes) diff --git a/cosmos/config.py b/cosmos/config.py index fa68b44ab..e052cb9db 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, Iterator, Callable -from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode +from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode, TestIndirectSelection from cosmos.dbt.executable import get_system_dbt from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger @@ -205,9 +205,11 @@ class ExecutionConfig: Contains configuration about how to execute dbt. :param execution_mode: The execution mode for dbt. Defaults to local + :param test_indirect_selection: The mode to configure the test behavior when performing indirect selection. :param dbt_executable_path: The path to the dbt executable. Defaults to dbt if available on the path. """ execution_mode: ExecutionMode = ExecutionMode.LOCAL + test_indirect_selection: TestIndirectSelection = TestIndirectSelection.EAGER dbt_executable_path: str | Path = get_system_dbt() diff --git a/cosmos/constants.py b/cosmos/constants.py index cd59c8173..9aa38c34e 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -51,6 +51,17 @@ class ExecutionMode(Enum): VIRTUALENV = "virtualenv" +class TestIndirectSelection(Enum): + """ + Modes to configure the test behavior when performing indirect selection. + """ + + EAGER = "eager" + CAUTIOUS = "cautious" + BUILDABLE = "buildable" + EMPTY = "empty" + + class DbtResourceType(aenum.Enum): # type: ignore """ Type of dbt node. diff --git a/cosmos/converter.py b/cosmos/converter.py index 8137da3ec..43865c56b 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -116,6 +116,7 @@ def __init__( exclude = render_config.exclude dbt_deps = render_config.dbt_deps execution_mode = execution_config.execution_mode + test_indirect_selection = execution_config.test_indirect_selection load_mode = render_config.load_method manifest_path = project_config.parsed_manifest_path dbt_executable_path = execution_config.dbt_executable_path @@ -167,6 +168,7 @@ def __init__( execution_mode=execution_mode, task_args=task_args, test_behavior=test_behavior, + test_indirect_selection=test_indirect_selection, dbt_project_name=dbt_project.name, on_warning_callback=on_warning_callback, node_converters=node_converters, diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index d43a2d241..6d276013d 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -89,6 +89,7 @@ def __init__( vars: dict[str, str] | None = None, models: str | None = None, emit_datasets: bool = True, + indirect_selection: str | None = None, cache_selected_only: bool = False, no_version_check: bool = False, fail_fast: bool = False, @@ -115,6 +116,7 @@ def __init__( self.vars = vars self.models = models self.emit_datasets = emit_datasets + self.indirect_selection = indirect_selection self.cache_selected_only = cache_selected_only self.no_version_check = no_version_check self.fail_fast = fail_fast @@ -213,6 +215,9 @@ def build_cmd( if self.base_cmd: dbt_cmd.extend(self.base_cmd) + if self.indirect_selection: + dbt_cmd += ["--indirect-selection", self.indirect_selection] + dbt_cmd.extend(self.add_global_flags()) # add command specific flags diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 2eb93c613..6bc244b6b 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -17,7 +17,7 @@ generate_task_or_group, ) from cosmos.config import ProfileConfig -from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior +from cosmos.constants import DbtResourceType, ExecutionMode, TestBehavior, TestIndirectSelection from cosmos.dbt.graph import DbtNode from cosmos.profiles import PostgresUserPasswordProfileMapping @@ -80,6 +80,7 @@ def test_build_airflow_graph_with_after_each(): nodes=sample_nodes, dag=dag, execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, task_args=task_args, test_behavior=TestBehavior.AFTER_EACH, dbt_project_name="astro_shop", @@ -129,6 +130,7 @@ def test_create_task_group_for_after_each_supported_nodes(node_type, task_suffix task_group=None, node=node, execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, task_args={ "project_dir": SAMPLE_PROJ_PATH, "profile_config": ProfileConfig( @@ -170,6 +172,7 @@ def test_build_airflow_graph_with_after_all(): nodes=sample_nodes, dag=dag, execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, task_args=task_args, test_behavior=TestBehavior.AFTER_ALL, dbt_project_name="astro_shop", @@ -332,15 +335,30 @@ def test_create_task_metadata_snapshot(caplog): @pytest.mark.parametrize( - "node_type,node_unique_id,selector_key,selector_value", + "node_type,node_unique_id,test_indirect_selection,additional_arguments", [ - (DbtResourceType.MODEL, "node_name", "models", "node_name"), - (DbtResourceType.SEED, "node_name", "select", "node_name"), - (DbtResourceType.SOURCE, "source.node_name", "select", "source:node_name"), - (DbtResourceType.SNAPSHOT, "node_name", "select", "node_name"), + (DbtResourceType.MODEL, "node_name", TestIndirectSelection.EAGER, {"models": "node_name"}), + ( + DbtResourceType.SEED, + "node_name", + TestIndirectSelection.CAUTIOUS, + {"select": "node_name", "indirect_selection": "cautious"}, + ), + ( + DbtResourceType.SOURCE, + "source.node_name", + TestIndirectSelection.BUILDABLE, + {"select": "source:node_name", "indirect_selection": "buildable"}, + ), + ( + DbtResourceType.SNAPSHOT, + "node_name", + TestIndirectSelection.EMPTY, + {"select": "node_name", "indirect_selection": "empty"}, + ), ], ) -def test_create_test_task_metadata(node_type, node_unique_id, selector_key, selector_value): +def test_create_test_task_metadata(node_type, node_unique_id, test_indirect_selection, additional_arguments): sample_node = DbtNode( name="node_name", unique_id=node_unique_id, @@ -353,10 +371,17 @@ def test_create_test_task_metadata(node_type, node_unique_id, selector_key, sele metadata = create_test_task_metadata( test_task_name="test_no_nulls", execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=test_indirect_selection, task_args={"task_arg": "value"}, on_warning_callback=True, node=sample_node, ) assert metadata.id == "test_no_nulls" assert metadata.operator_class == "cosmos.operators.local.DbtTestLocalOperator" - assert metadata.arguments == {"task_arg": "value", "on_warning_callback": True, selector_key: selector_value} + assert metadata.arguments == { + **{ + "task_arg": "value", + "on_warning_callback": True, + }, + **additional_arguments, + } diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index e5f210c53..07c186d4a 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -117,6 +117,29 @@ def test_dbt_base_operator_add_user_supplied_global_flags() -> None: assert cmd[-1] == "run" +@pytest.mark.parametrize( + "indirect_selection_type", + [None, "cautious", "buildable", "empty"], +) +def test_dbt_base_operator_use_indirect_selection(indirect_selection_type) -> None: + dbt_base_operator = DbtLocalBaseOperator( + profile_config=profile_config, + task_id="my-task", + project_dir="my/dir", + base_cmd=["run"], + indirect_selection=indirect_selection_type, + ) + + cmd, _ = dbt_base_operator.build_cmd( + Context(execution_date=datetime(2023, 2, 15, 12, 30)), + ) + if indirect_selection_type: + assert cmd[-2] == "--indirect-selection" + assert cmd[-1] == indirect_selection_type + else: + assert cmd == ["dbt", "run"] + + @pytest.mark.parametrize( ["skip_exception", "exception_code_returned", "expected_exception"], [