diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 627e0d1468c95..c0c9474913482 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -33,7 +33,7 @@ from sqlalchemy import and_, delete, func, not_, or_, select, text, update from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import joinedload, load_only, make_transient, selectinload +from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload from sqlalchemy.sql import expression from airflow import settings @@ -1633,13 +1633,10 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int: query = ( select(TI) + .options(lazyload(TI.dag_run)) # avoids double join to dag_run .where(TI.state.in_(State.adoptable_states)) - # outerjoin is because we didn't use to have queued_by_job - # set, so we need to pick up anything pre upgrade. This (and the - # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is - # released. 
- .outerjoin(TI.queued_by_job) - .where(or_(TI.queued_by_job_id.is_(None), Job.state != JobState.RUNNING)) + .join(TI.queued_by_job) + .where(Job.state.is_distinct_from(JobState.RUNNING)) .join(TI.dag_run) .where( DagRun.run_type != DagRunType.BACKFILL_JOB, diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 8e745e1e083f1..08d994d50c828 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -1622,10 +1622,15 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker): start_date=DEFAULT_DATE, session=session, ) + scheduler_job = Job() + session.add(scheduler_job) + session.commit() ti = dr.get_task_instance(task_id=op1.task_id, session=session) ti.state = State.QUEUED + ti.queued_by_job_id = scheduler_job.id ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) ti2.state = State.QUEUED + ti2.queued_by_job_id = scheduler_job.id session.commit() processor = mock.MagicMock() @@ -1636,6 +1641,7 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker): self.job_runner.adopt_or_reset_orphaned_tasks() ti = dr.get_task_instance(task_id=op1.task_id, session=session) + assert ti.state == State.NONE ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session) @@ -3153,19 +3159,21 @@ def test_adopt_or_reset_orphaned_tasks_nothing(self): "adoptable_state", list(sorted(State.adoptable_states)), ) - def test_adopt_or_reset_resettable_tasks(self, dag_maker, adoptable_state): + def test_adopt_or_reset_resettable_tasks(self, dag_maker, adoptable_state, session): dag_id = "test_adopt_or_reset_adoptable_tasks_" + adoptable_state.name with dag_maker(dag_id=dag_id, schedule="@daily"): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) - + old_job = Job() + session.add(old_job) + session.commit() scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) - session = settings.Session() dr1 = dag_maker.create_dagrun(external_trigger=True) ti = 
dr1.get_task_instances(session=session)[0] ti.state = adoptable_state + ti.queued_by_job_id = old_job.id session.merge(ti) session.merge(dr1) session.commit() @@ -3173,12 +3181,15 @@ def test_adopt_or_reset_resettable_tasks(self, dag_maker, adoptable_state): num_reset_tis = self.job_runner.adopt_or_reset_orphaned_tasks(session=session) assert 1 == num_reset_tis - def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker): + def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker, session): dag_id = "test_reset_orphaned_tasks_external_triggered_dag" with dag_maker(dag_id=dag_id, schedule="@daily"): task_id = dag_id + "_task" EmptyOperator(task_id=task_id) + old_job = Job() + session.add(old_job) + session.flush() scheduler_job = Job() self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() @@ -3186,12 +3197,13 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self, dag_maker): dr1 = dag_maker.create_dagrun(external_trigger=True) ti = dr1.get_task_instances(session=session)[0] ti.state = State.QUEUED + ti.queued_by_job_id = old_job.id session.merge(ti) session.merge(dr1) session.commit() num_reset_tis = self.job_runner.adopt_or_reset_orphaned_tasks(session=session) - assert 1 == num_reset_tis + assert num_reset_tis == 1 def test_adopt_or_reset_orphaned_tasks_backfill_dag(self, dag_maker): dag_id = "test_adopt_or_reset_orphaned_tasks_backfill_dag" @@ -3224,6 +3236,7 @@ def test_reset_orphaned_tasks_no_orphans(self, dag_maker): EmptyOperator(task_id=task_id) scheduler_job = Job() + scheduler_job.state = "running" self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) session = settings.Session() session.add(scheduler_job) @@ -3237,9 +3250,9 @@ def test_reset_orphaned_tasks_no_orphans(self, dag_maker): session.merge(tis[0]) session.flush() - assert 0 == self.job_runner.adopt_or_reset_orphaned_tasks(session=session) + assert 
self.job_runner.adopt_or_reset_orphaned_tasks(session=session) == 0 tis[0].refresh_from_db() - assert State.RUNNING == tis[0].state + assert tis[0].state == State.RUNNING def test_reset_orphaned_tasks_non_running_dagruns(self, dag_maker): """Ensure orphaned tasks with non-running dagruns are not reset."""