From fe27031382e2034b59a23db1c6b9bdbfef259137 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 29 Aug 2023 17:48:43 +0100 Subject: [PATCH] Fix MappedTaskGroup tasks not respecting upstream dependency (#33732) * Fix MappedTaskGroup tasks not respecting upstream dependency When a MappedTaskGroup has upstream dependencies, the tasks in the group don't wait for the upstream tasks before they start running, which causes the tasks to fail. From my investigation, the tasks inside the MappedTaskGroup don't have upstream tasks while the MappedTaskGroup has the upstream tasks properly set. Due to this, the tasks' dependencies are considered met even though the group has upstreams that haven't finished. The fix was to set the upstreams after creating the task group with the factory. Closes: https://github.com/apache/airflow/issues/33446 * set the relationship in __exit__ --- airflow/utils/task_group.py | 7 +++-- tests/decorators/test_task_group.py | 46 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index e892f0e85dbb3..0b41929f36ba4 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -566,8 +566,6 @@ class MappedTaskGroup(TaskGroup): def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None: super().__init__(**kwargs) self._expand_input = expand_input - for op, _ in expand_input.iter_references(): - self.set_upstream(op) def iter_mapped_dependencies(self) -> Iterator[Operator]: """Upstream dependencies that provide XComs used by this mapped task group.""" @@ -620,6 +618,11 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: (g._expand_input.get_total_map_length(run_id, session=session) for g in groups), ) + def __exit__(self, exc_type, exc_val, exc_tb): + for op, _ in self._expand_input.iter_references(): + self.set_upstream(op) + super().__exit__(exc_type, exc_val, exc_tb) + class TaskGroupContext: """TaskGroup 
context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py index 3462c3a1d83a5..4c741ef1c15f1 100644 --- a/tests/decorators/test_task_group.py +++ b/tests/decorators/test_task_group.py @@ -191,6 +191,52 @@ def tg(a, b): assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")} +def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [{"a": 1}, {"a": 2}] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.expand_kwargs(t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['expand_kwargs() argument']" not in caplog.text + + +def test_task_group_expand_with_upstream(dag_maker, session, caplog): + with dag_maker() as dag: + + @dag.task + def t1(): + return [1, 2, 3] + + @task_group("tg1") + def tg1(a, b): + @dag.task() + def t2(): + return [a, b] + + t2() + + tg1.partial(a=1).expand(b=t1()) + + dr = dag_maker.create_dagrun() + dr.task_instance_scheduling_decisions() + assert "Cannot expand" not in caplog.text + assert "missing upstream values: ['b']" not in caplog.text + + def test_override_dag_default_args(): @dag( dag_id="test_dag",