diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 6605bf20d..3e3103266 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -99,6 +99,7 @@ def create_test_task_metadata( extra_context = {} task_owner = "" + airflow_task_config = {} if test_indirect_selection != TestIndirectSelection.EAGER: task_args["indirect_selection"] = test_indirect_selection.value if node is not None: @@ -111,6 +112,7 @@ def create_test_task_metadata( extra_context = {"dbt_node_config": node.context_dict} task_owner = node.owner + airflow_task_config = node.airflow_task_config elif render_config is not None: # TestBehavior.AFTER_ALL task_args["select"] = render_config.select @@ -120,6 +122,7 @@ def create_test_task_metadata( return TaskMetadata( id=test_task_name, owner=task_owner, + airflow_task_config=airflow_task_config, operator_class=calculate_operator_class( execution_mode=execution_mode, dbt_class="DbtTest", @@ -214,6 +217,7 @@ def create_task_metadata( task_metadata = TaskMetadata( id=task_id, owner=node.owner, + airflow_task_config=node.airflow_task_config, operator_class=calculate_operator_class( execution_mode=execution_mode, dbt_class=dbt_resource_to_class[node.resource_type] ), diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index 6f1064649..e25404aed 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -32,6 +32,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) if task.owner != "": task_kwargs["owner"] = task.owner + for k, v in task.airflow_task_config.items(): + task_kwargs[k] = v + airflow_task = Operator( task_id=task.id, dag=dag, diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py index 6bf9ff046..cdd5485a6 100644 --- a/cosmos/core/graph/entities.py +++ b/cosmos/core/graph/entities.py @@ -58,6 +58,7 @@ class Task(CosmosEntity): """ owner: str = "" + airflow_task_config: Dict[str, Any] = field(default_factory=dict) operator_class: str = "airflow.operators.empty.EmptyOperator" arguments: Dict[str, Any] = field(default_factory=dict) extra_context: Dict[str, Any] = field(default_factory=dict) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index 6c207f2ca..04a7425e7 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -13,7 +13,7 @@ from functools import cached_property from pathlib import Path from subprocess import PIPE, Popen -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from airflow.models import Variable @@ -67,6 +67,33 @@ class DbtNode: has_freshness: bool = False has_test: bool = False + @property + def airflow_task_config(self) -> Dict[str, Any]: + """ + This method is designed to extend the dbt project's functionality by incorporating Airflow-related metadata into the dbt YAML configuration. + Since dbt projects are independent of Airflow, adding Airflow-specific information to the `meta` field within the dbt YAML allows Airflow tasks to + utilize this information during execution. + + Examples: pool, pool_slots, queue, ... + Returns: + Dict[str, Any]: A dictionary containing custom metadata configurations for integration with Airflow. + """ + + if "meta" in self.config: + meta = self.config["meta"] + if "cosmos" in meta: + cosmos = meta["cosmos"] + if isinstance(cosmos, dict): + if "operator_kwargs" in cosmos: + operator_kwargs = cosmos["operator_kwargs"] + if isinstance(operator_kwargs, dict): + return operator_kwargs + else: + logger.error(f"Invalid type: 'operator_kwargs' in meta.cosmos must be a dict.") + else: + logger.error(f"Invalid type: 'cosmos' in meta must be a dict.") + return {} + @property def resource_name(self) -> str: """ diff --git a/docs/_static/custom_airflow_pool.png b/docs/_static/custom_airflow_pool.png new file mode 100644 index 000000000..4b4163e66 Binary files /dev/null and b/docs/_static/custom_airflow_pool.png differ diff --git a/docs/getting_started/custom-airflow-properties.rst b/docs/getting_started/custom-airflow-properties.rst new file mode 100644 index 000000000..90490a099 --- /dev/null +++ b/docs/getting_started/custom-airflow-properties.rst @@ -0,0 +1,33 @@ +.. _custom-airflow-properties: + +Airflow Configuration Overrides with Astronomer Cosmos +====================================================== + +**Astronomer Cosmos** allows you to override Airflow configurations for each dbt task (dbt operator) via the dbt YAML file. + +Sample dbt Model YAML +++++++++++++ + +.. code-block:: yaml + + version: 2 + models: + - name: name + description: description + meta: + cosmos: + operator_args: + pool: abcd + + + + +Explanation +++++++++++++ + +By adding Airflow configurations under **cosmos** in the **meta** field, you can set independent Airflow configurations for each task. +For example, in the YAML above, the **pool** setting is applied to the specific dbt task. +This approach allows for more granular control over Airflow settings per task within your dbt model definitions. + +.. image:: ../_static/custom_airflow_pool.png + :alt: Result of applying Custom Airflow Pool diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index d2f943bab..fc0070e8b 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -62,7 +62,7 @@ depends_on=[parent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen3/models/child.sql", tags=["nightly"], - config={"materialized": "table"}, + config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"queue": "custom_queue"}}}}, ) child2_node = DbtNode( @@ -71,7 +71,7 @@ depends_on=[parent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen3/models/child2_v2.sql", tags=["nightly"], - config={"materialized": "table"}, + config={"materialized": "table", "meta": {"cosmos": {"operator_kwargs": {"pool": "custom_pool"}}}}, ) sample_nodes_list = [parent_seed, parent_node, test_parent_node, child_node, child2_node] @@ -750,3 +750,42 @@ def test_owner(dbt_extra_config, expected_owner): assert len(output.leaves) == 1 assert output.leaves[0].owner == expected_owner + + +def test_custom_meta(): + with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.LOCAL, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + render_config=RenderConfig( + test_behavior=TestBehavior.AFTER_EACH, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ), + dbt_project_name="astro_shop", + ) + # test custom meta (queue, pool) + for task in dag.tasks: + if task.task_id == "child2_v2_run": + assert task.pool == "custom_pool" + else: + assert task.pool == "default_pool" + + if task.task_id == "child_run": + assert task.queue == "custom_queue" + else: + assert task.queue == "default"