Skip to content

Commit

Permalink
To support task display_name
Browse files Browse the repository at this point in the history
Co-authored-by: t.kodama <[email protected]>
  • Loading branch information
pankajastro and t0momi219 committed Dec 19, 2024
1 parent c5edba0 commit 2f5dd93
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 12 deletions.
49 changes: 37 additions & 12 deletions cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ def create_dbt_resource_to_class(test_behavior: TestBehavior) -> dict[str, str]:
return dbt_resource_to_class


def _get_task_id_and_args(
    node: DbtNode,
    args: dict[str, Any],
    use_task_group: bool,
    normalize_task_id: Callable[..., Any] | None,
    resource_suffix: str,
) -> tuple[str, dict[str, Any]]:
    """
    Generate the task ID and the operator arguments for a dbt node.

    The task ID is selected as follows:

    * if ``use_task_group`` is true, the ID is ``resource_suffix`` and ``normalize_task_id`` is not called;
    * otherwise, if ``normalize_task_id`` is given, the ID is ``normalize_task_id(node)`` and the default
      ``<node.name>_<resource_suffix>`` string is stored under the ``task_display_name`` argument;
    * otherwise the ID is ``<node.name>_<resource_suffix>``.

    :param node: The dbt node the task is created for.
    :param args: Operator arguments. This dictionary is never modified; a shallow copy is returned instead.
    :param use_task_group: Whether the task is rendered inside a task group.
    :param normalize_task_id: Optional callable that receives ``node`` and returns a custom task ID.
    :param resource_suffix: Suffix identifying the kind of task (for example ``run``, ``build``, ``source``).
    :returns: A ``(task_id, args)`` tuple, where ``args`` is a new dictionary.
    """
    # Copy so that adding ``task_display_name`` never leaks into the dictionary owned by the caller.
    args_update = dict(args)
    if use_task_group:
        task_id = resource_suffix
    elif normalize_task_id:
        task_id = normalize_task_id(node)
        args_update["task_display_name"] = f"{node.name}_{resource_suffix}"
    else:
        task_id = f"{node.name}_{resource_suffix}"
    return task_id, args_update


def create_task_metadata(
node: DbtNode,
execution_mode: ExecutionMode,
Expand All @@ -165,6 +186,7 @@ def create_task_metadata(
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
normalize_task_id: Callable[..., Any] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
Expand All @@ -188,31 +210,30 @@ def create_task_metadata(
"dbt_dag_task_group_identifier": dbt_dag_task_group_identifier,
}
if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
task_id = f"{node.name}_{node.resource_type.value}_build"
task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "build")
elif node.resource_type == DbtResourceType.MODEL:
if use_task_group:
task_id = "run"
else:
task_id = f"{node.name}_run"
task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "run")
elif node.resource_type == DbtResourceType.SOURCE:
if (source_rendering_behavior == SourceRenderingBehavior.NONE) or (
source_rendering_behavior == SourceRenderingBehavior.WITH_TESTS_OR_FRESHNESS
and node.has_freshness is False
and node.has_test is False
):
return None
task_id = f"{node.name}_source"
args["select"] = f"source:{node.resource_name}"
args.pop("models")
if use_task_group is True:
task_id = node.resource_type.value
task_id, args = _get_task_id_and_args(node, args, use_task_group, normalize_task_id, "source")
if node.has_freshness is False and source_rendering_behavior == SourceRenderingBehavior.ALL:
# render sources without freshness as empty operators
return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator")
if "task_display_name" in args:
args = {"task_display_name": args["task_display_name"]}
else:
args = {}
return TaskMetadata(id=task_id, operator_class="airflow.operators.empty.EmptyOperator", arguments=args)
else:
task_id = f"{node.name}_{node.resource_type.value}"
if use_task_group is True:
task_id = node.resource_type.value
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, node.resource_type.value
)

task_metadata = TaskMetadata(
id=task_id,
Expand Down Expand Up @@ -244,6 +265,7 @@ def generate_task_or_group(
source_rendering_behavior: SourceRenderingBehavior,
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
normalize_task_id: Callable[..., Any] | None = None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None
Expand All @@ -262,6 +284,7 @@ def generate_task_or_group(
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
test_behavior=test_behavior,
normalize_task_id=normalize_task_id,
)

# In most cases, we'll map one DBT node to one Airflow task
Expand Down Expand Up @@ -364,6 +387,7 @@ def build_airflow_graph(
node_converters = render_config.node_converters or {}
test_behavior = render_config.test_behavior
source_rendering_behavior = render_config.source_rendering_behavior
normalize_task_id = render_config.normalize_task_id
tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
task_or_group: TaskGroup | BaseOperator

Expand All @@ -385,6 +409,7 @@ def build_airflow_graph(
source_rendering_behavior=source_rendering_behavior,
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
normalize_task_id=normalize_task_id,
node=node,
)
if task_or_group is not None:
Expand Down
3 changes: 3 additions & 0 deletions cosmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class RenderConfig:
:param dbt_ls_path: Configures the location of an output of ``dbt ls``. Required when using ``load_method=LoadMode.DBT_LS_FILE``.
:param enable_mock_profile: Allows to enable/disable mocking profile. Enabled by default. Mock profiles are useful for parsing Cosmos DAGs in the CI, but should be disabled to benefit from partial parsing (since Cosmos 1.4).
:param source_rendering_behavior: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6).
:param airflow_vars_to_purge_dbt_ls_cache: Specify Airflow variables that will affect the LoadMode.DBT_LS cache.
:param normalize_task_id: A callable that takes a dbt node as input and returns the task ID. This allows users to assign a custom node ID separate from the display name.
"""

emit_datasets: bool = True
Expand All @@ -80,6 +82,7 @@ class RenderConfig:
enable_mock_profile: bool = True
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE
airflow_vars_to_purge_dbt_ls_cache: list[str] = field(default_factory=list)
normalize_task_id: Callable[..., Any] | None = None

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
if self.env_vars:
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ Cosmos offers a number of configuration options to customize its behavior. For m
Logging <logging>
Caching <caching>
Callbacks <callbacks>
Task display name <task-display-name>
3 changes: 3 additions & 0 deletions docs/configuration/render-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ The ``RenderConfig`` class takes the following arguments:
- ``env_vars``: (available in v1.2.5, use ``ProjectConfig.env_vars`` for v1.3.0 onwards) A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``.
- ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``
- ``airflow_vars_to_purge_dbt_ls_cache``: (new in v1.5) Specify Airflow variables that will affect the ``LoadMode.DBT_LS`` cache. See `Caching <./caching.html>`_ for more information.
- ``source_rendering_behavior``: Determines how source nodes are rendered when using cosmos default source node rendering (ALL, NONE, WITH_TESTS_OR_FRESHNESS). Defaults to "NONE" (since Cosmos 1.6). See `Source Nodes Rendering <./source-nodes-rendering.html>`_ for more information.
- ``normalize_task_id``: A callable that takes a dbt node as input and returns the task ID. This function allows users to set a custom task_id independently of the model name, which can be specified as the task’s display_name. This way, task_id can be modified using a user-defined function, while the model name remains as the task’s display name. The display_name parameter is available in Airflow 2.9 and above. See `Task display name <./task-display-name.html>`_ for more information.


Customizing how nodes are rendered (experimental)
-------------------------------------------------
Expand Down
33 changes: 33 additions & 0 deletions docs/configuration/task-display-name.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
.. _task-display-name:

Task display name
=================

.. note::
    This feature is only available for Airflow >= 2.9.

In Airflow, ``task_id`` does not support non-ASCII characters. Therefore, if users wish to use non-ASCII characters (such as their native language) as display names while keeping ``task_id`` in ASCII, they can use the ``display_name`` parameter.

To work with projects that use non-ASCII characters in model names, the ``normalize_task_id`` field of ``RenderConfig`` can be utilized.

Example:

You can provide a function to convert the model name to an ASCII-compatible format. The function’s output is used as the TaskID, while the display name on Airflow remains as the original model name.

.. code-block:: python

    from slugify import slugify


    def normalize_task_id(node):
        return slugify(node.name)


    from cosmos import DbtTaskGroup, RenderConfig

    jaffle_shop = DbtTaskGroup(
        render_config=RenderConfig(normalize_task_id=normalize_task_id)
    )

.. note::
    Although the slugify example often works, it may not be suitable for use in actual production. Since slugify performs conversions based on pronunciation, there may be cases where task_id is not unique due to homophones and similar issues.
117 changes: 117 additions & 0 deletions tests/airflow/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,123 @@ def test_create_task_metadata_snapshot(caplog):
assert metadata.arguments == {"models": "my_snapshot"}


def _normalize_task_id(node: DbtNode) -> str:
    """Return a custom task ID built from the node's name and resource type.

    Helper passed as ``normalize_task_id`` in ``test_create_task_metadata_normalize_task_id``.
    """
    resource_kind = node.resource_type.value
    return "new_task_id_{}_{}".format(node.name, resource_kind)


@pytest.mark.skipif(
    version.parse(airflow_version) < version.parse("2.9"),
    reason="Airflow task did not have display_name until the 2.9 release",
)
@pytest.mark.parametrize(
    "node_type,node_id,normalize_task_id,use_task_group,expected_node_id,expected_display_name",
    [
        # normalize_task_id is None (default)
        (
            DbtResourceType.MODEL,
            f"{DbtResourceType.MODEL.value}.my_folder.test_node",
            None,
            False,
            "test_node_run",
            None,
        ),
        (
            DbtResourceType.SOURCE,
            f"{DbtResourceType.SOURCE.value}.my_folder.test_node",
            None,
            False,
            "test_node_source",
            None,
        ),
        (
            DbtResourceType.SEED,
            f"{DbtResourceType.SEED.value}.my_folder.test_node",
            None,
            False,
            "test_node_seed",
            None,
        ),
        # normalize_task_id is passed and use_task_group is False
        (
            DbtResourceType.MODEL,
            f"{DbtResourceType.MODEL.value}.my_folder.test_node",
            _normalize_task_id,
            False,
            "new_task_id_test_node_model",
            "test_node_run",
        ),
        (
            DbtResourceType.SOURCE,
            f"{DbtResourceType.SOURCE.value}.my_folder.test_node",
            _normalize_task_id,
            False,
            "new_task_id_test_node_source",
            "test_node_source",
        ),
        (
            DbtResourceType.SEED,
            f"{DbtResourceType.SEED.value}.my_folder.test_node",
            _normalize_task_id,
            False,
            "new_task_id_test_node_seed",
            "test_node_seed",
        ),
        # normalize_task_id is passed and use_task_group is True
        (
            DbtResourceType.MODEL,
            f"{DbtResourceType.MODEL.value}.my_folder.test_node",
            _normalize_task_id,
            True,
            "run",
            None,
        ),
        (
            DbtResourceType.SOURCE,
            f"{DbtResourceType.SOURCE.value}.my_folder.test_node",
            _normalize_task_id,
            True,
            "source",
            None,
        ),
        (
            DbtResourceType.SEED,
            f"{DbtResourceType.SEED.value}.my_folder.test_node",
            _normalize_task_id,
            True,
            "seed",
            None,
        ),
    ],
)
def test_create_task_metadata_normalize_task_id(
    node_type, node_id, normalize_task_id, use_task_group, expected_node_id, expected_display_name
):
    """
    Check the task ID and ``task_display_name`` produced by ``create_task_metadata``.

    Three scenarios are covered for model, source and seed nodes:

    * no ``normalize_task_id``: the default ID is used and no display name is set;
    * ``normalize_task_id`` without a task group: the callable's result is the ID and the default ID
      becomes the display name;
    * ``normalize_task_id`` with a task group: the ID is the resource suffix and no display name is set.

    Each ``node_id`` uses the prefix of its own ``node_type`` so the fixture data is self-consistent.
    """
    node = DbtNode(
        unique_id=node_id,
        resource_type=node_type,
        depends_on=[],
        file_path="",
        tags=[],
        config={},
    )
    args = {}
    metadata = create_task_metadata(
        node,
        execution_mode=ExecutionMode.LOCAL,
        args=args,
        dbt_dag_task_group_identifier="",
        use_task_group=use_task_group,
        normalize_task_id=normalize_task_id,
        # ALL makes the source parametrizations produce metadata instead of returning None.
        source_rendering_behavior=SourceRenderingBehavior.ALL,
    )
    # Fail with a clear message rather than an AttributeError if no metadata is produced.
    assert metadata is not None
    assert metadata.id == expected_node_id
    if expected_display_name:
        assert metadata.arguments["task_display_name"] == expected_display_name
    else:
        assert "task_display_name" not in metadata.arguments


@pytest.mark.parametrize(
"node_type,node_unique_id,test_indirect_selection,additional_arguments",
[
Expand Down

0 comments on commit 2f5dd93

Please sign in to comment.