Skip to content

Commit

Permalink
Fix MappedTaskGroup tasks not respecting upstream dependency (apache#…
Browse files Browse the repository at this point in the history
…33732)

* Fix MappedTaskGroup tasks not respecting upstream dependency

When a MappedTaskGroup has upstream dependencies, the tasks in the group don't wait for the upstream tasks
to finish before they start running, which causes the tasks to fail.
From my investigation, the tasks inside the MappedTaskGroup don't have upstream tasks, while the
MappedTaskGroup itself has its upstream tasks properly set. Because of this, the tasks' dependencies are considered met even though the group has
upstreams that haven't finished.
The fix was to set the upstreams after creating the task group with the factory.
Closes: apache#33446

* set the relationship in __exit__
  • Loading branch information
ephraimbuddy authored Aug 29, 2023
1 parent 869f84e commit fe27031
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
7 changes: 5 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,6 @@ class MappedTaskGroup(TaskGroup):
def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
    """Initialise the group and register its expand-input references as upstreams.

    :param expand_input: Mapping input for this group. Every operator yielded
        by its ``iter_references()`` is set as an upstream of the group.
    :param kwargs: Forwarded unchanged to the parent class initialiser.
    """
    super().__init__(**kwargs)
    self._expand_input = expand_input

    # Each reference is a pair; only the operator (first element) is needed.
    references = expand_input.iter_references()
    for upstream_operator, _ in references:
        self.set_upstream(upstream_operator)

def iter_mapped_dependencies(self) -> Iterator[Operator]:
"""Upstream dependencies that provide XComs used by this mapped task group."""
Expand Down Expand Up @@ -620,6 +618,11 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
(g._expand_input.get_total_map_length(run_id, session=session) for g in groups),
)

def __exit__(self, exc_type, exc_val, exc_tb):
    """Set upstream relationships on leaving the ``with`` block, then delegate to the parent.

    Every operator yielded by ``self._expand_input.iter_references()`` is set
    as an upstream of this group before the parent class's ``__exit__`` runs.

    :param exc_type: Exception type passed by the ``with`` statement, if any.
    :param exc_val: Exception instance passed by the ``with`` statement, if any.
    :param exc_tb: Traceback passed by the ``with`` statement, if any.
    """
    # Each reference is a pair; only the operator (first element) is needed.
    references = self._expand_input.iter_references()
    for upstream_operator, _ in references:
        self.set_upstream(upstream_operator)

    super().__exit__(exc_type, exc_val, exc_tb)


class TaskGroupContext:
"""TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
Expand Down
46 changes: 46 additions & 0 deletions tests/decorators/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,52 @@ def tg(a, b):
assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")}


def test_task_group_expand_kwargs_with_upstream(dag_maker, session, caplog):
    """Scheduling a group mapped via ``expand_kwargs()`` over a task's output logs no expansion error.

    Builds a DAG where task ``t1`` feeds ``tg1.expand_kwargs(...)``, creates a
    DAG run, asks it for scheduling decisions, and checks that neither of the
    expansion-failure messages appears in the captured log text.
    """
    with dag_maker() as dag:

        @dag.task
        def t1():
            return [{"a": 1}, {"a": 2}]

        @task_group("tg1")
        def tg1(a, b):
            @dag.task()
            def t2():
                return [a, b]

            t2()

        # The output of t1 is what the task group is mapped over.
        upstream_output = t1()
        tg1.expand_kwargs(upstream_output)

    dag_run = dag_maker.create_dagrun()
    dag_run.task_instance_scheduling_decisions()

    unexpected_messages = (
        "Cannot expand",
        "missing upstream values: ['expand_kwargs() argument']",
    )
    for message in unexpected_messages:
        assert message not in caplog.text


def test_task_group_expand_with_upstream(dag_maker, session, caplog):
    """Scheduling a group mapped via ``partial().expand()`` over a task's output logs no expansion error.

    Builds a DAG where task ``t1`` feeds ``tg1.partial(a=1).expand(b=...)``,
    creates a DAG run, asks it for scheduling decisions, and checks that
    neither of the expansion-failure messages appears in the captured log text.
    """
    with dag_maker() as dag:

        @dag.task
        def t1():
            return [1, 2, 3]

        @task_group("tg1")
        def tg1(a, b):
            @dag.task()
            def t2():
                return [a, b]

            t2()

        # Fix ``a`` first, then map ``b`` over the output of t1.
        partially_applied = tg1.partial(a=1)
        partially_applied.expand(b=t1())

    dag_run = dag_maker.create_dagrun()
    dag_run.task_instance_scheduling_decisions()

    unexpected_messages = (
        "Cannot expand",
        "missing upstream values: ['b']",
    )
    for message in unexpected_messages:
        assert message not in caplog.text


def test_override_dag_default_args():
@dag(
dag_id="test_dag",
Expand Down

0 comments on commit fe27031

Please sign in to comment.