diff --git a/.github/boring-cyborg.yml b/.github/boring-cyborg.yml index 6c6627ecd5a71..71128dafa4d22 100644 --- a/.github/boring-cyborg.yml +++ b/.github/boring-cyborg.yml @@ -652,12 +652,12 @@ labelPRBasedOnFilePath: - airflow/cli/commands/triggerer_command.py - airflow/jobs/triggerer_job_runner.py - airflow/models/trigger.py - - airflow/triggers/**/* + - providers/src/airflow/providers/standard/triggers/**/* - tests/cli/commands/test_triggerer_command.py - tests/jobs/test_triggerer_job.py - tests/models/test_trigger.py - tests/jobs/test_triggerer_job_logging.py - - tests/triggers/**/* + - providers/tests/standard/triggers/**/* area:Serialization: - airflow/serialization/**/* diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py index 0b5793f06ced7..ff43ff2f463f7 100644 --- a/airflow/sensors/external_task.py +++ b/airflow/sensors/external_task.py @@ -30,11 +30,11 @@ from airflow.models.dagbag import DagBag from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator +from airflow.providers.standard.triggers.external_task import WorkflowTrigger +from airflow.providers.standard.utils.sensor_helper import _get_count, _get_external_task_group_task_ids from airflow.sensors.base import BaseSensorOperator -from airflow.triggers.external_task import WorkflowTrigger from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import build_airflow_url_with_query -from airflow.utils.sensor_helper import _get_count, _get_external_task_group_task_ids from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State, TaskInstanceState diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 740f8c5b3e53f..86a72a4de3838 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -130,7 +130,6 @@ "tests/template", "tests/testconfig", "tests/timetables", - "tests/triggers", ], ), ( diff --git a/docs/apache-airflow/authoring-and-scheduling/deferring.rst b/docs/apache-airflow/authoring-and-scheduling/deferring.rst index 0b477151a9091..c208e8cedcbc9 100644 --- a/docs/apache-airflow/authoring-and-scheduling/deferring.rst +++ b/docs/apache-airflow/authoring-and-scheduling/deferring.rst @@ -68,7 +68,7 @@ When writing a deferrable operators these are the main points to consider: from airflow.configuration import conf from airflow.sensors.base import BaseSensorOperator - from airflow.triggers.temporal import TimeDeltaTrigger + from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger from airflow.utils.context import Context @@ -122,7 +122,7 @@ This example shows the structure of a basic trigger, a very simplified version o self.moment = moment def serialize(self): - return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) + return ("airflow.providers.standard.triggers.temporal.DateTimeTrigger", {"moment": self.moment}) async def run(self): while self.moment > timezone.utcnow(): @@ -177,7 +177,7 @@ Here's a basic example of how a sensor might trigger deferral: from typing import TYPE_CHECKING, Any from airflow.sensors.base import BaseSensorOperator - from airflow.triggers.temporal import TimeDeltaTrigger + from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -237,7 +237,7 @@ In the sensor part, we'll need to provide the path to 
``TimeDeltaTrigger`` as `` class WaitOneHourSensor(BaseSensorOperator): start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={"moment": timedelta(hours=1)}, next_method="execute_complete", next_kwargs=None, @@ -268,7 +268,7 @@ In the sensor part, we'll need to provide the path to ``TimeDeltaTrigger`` as `` class WaitHoursSensor(BaseSensorOperator): start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={"moment": timedelta(hours=1)}, next_method="execute_complete", next_kwargs=None, @@ -307,7 +307,7 @@ After the trigger has finished executing, the task may be sent back to the worke class WaitHoursSensor(BaseSensorOperator): start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={"moment": timedelta(hours=1)}, next_method="execute_complete", next_kwargs=None, diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 24745db0f4002..5de68bdce55bc 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -846,6 +846,7 @@ "plugins": [], "cross-providers-deps": [ "amazon", + "common.compat", "google", "oracle", "sftp" diff --git a/providers/src/airflow/providers/common/compat/standard/triggers.py b/providers/src/airflow/providers/common/compat/standard/triggers.py new file mode 100644 index 0000000000000..1f7f524e8867c --- /dev/null +++ b/providers/src/airflow/providers/common/compat/standard/triggers.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger +else: + try: + from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger + except ModuleNotFoundError: + from airflow.triggers.temporal import TimeDeltaTrigger + + +__all__ = ["TimeDeltaTrigger"] diff --git a/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 904e5241c22b6..42c5852900567 100644 --- a/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Any, Callable from airflow.exceptions import AirflowException +from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer from airflow.sensors.base import BaseSensorOperator -from airflow.triggers.temporal import TimeDeltaTrigger if TYPE_CHECKING: from datetime import timedelta diff --git a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py index 0cd3b8861a54b..fcc066778330b 100644 --- a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -40,7 +40,7 @@ from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun from airflow.models.xcom import XCom -from airflow.triggers.external_task import DagStateTrigger +from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.session import provide_session diff --git a/providers/src/airflow/providers/standard/provider.yaml b/providers/src/airflow/providers/standard/provider.yaml index 5d0c02aaa0398..eea8991e2526e 100644 --- a/providers/src/airflow/providers/standard/provider.yaml +++ b/providers/src/airflow/providers/standard/provider.yaml @@ -68,6 +68,13 @@ hooks: - airflow.providers.standard.hooks.package_index - airflow.providers.standard.hooks.subprocess +triggers: + - integration-name: Standard + python-modules: + - airflow.providers.standard.triggers.external_task + - airflow.providers.standard.triggers.file + - airflow.providers.standard.triggers.temporal + config: standard: description: Options for the standard provider operators. 
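The new `airflow.providers.common.compat.standard.triggers` shim above gives provider code a single import location for `TimeDeltaTrigger` that works on both sides of this move: where the standard provider is installed it re-exports the new path, otherwise it falls back to `airflow.triggers.temporal`. A minimal sketch of a deferrable sensor consuming the shim (the sensor class is illustrative, echoing the `WaitOneHourSensor` from the deferring docs, and is not part of this change):

```python
from __future__ import annotations

from datetime import timedelta
from typing import Any

# Resolves to the standard provider's trigger when it is installed and falls
# back to airflow.triggers.temporal on cores that predate this move.
from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger
from airflow.sensors.base import BaseSensorOperator


class WaitOneHourSensor(BaseSensorOperator):
    """Illustrative sensor: defers for one hour via the compat import."""

    def execute(self, context) -> None:
        # Hand the wait off to the triggerer instead of blocking a worker slot.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(hours=1)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event: Any = None) -> None:
        # The trigger fired; the task finishes successfully.
        return None
```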
diff --git a/providers/src/airflow/providers/standard/sensors/date_time.py b/providers/src/airflow/providers/standard/sensors/date_time.py index 44e3b44cae76d..65ca95da5cc84 100644 --- a/providers/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/src/airflow/providers/standard/sensors/date_time.py @@ -22,6 +22,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn +from airflow.providers.standard.triggers.temporal import DateTimeTrigger from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator @@ -40,7 +41,6 @@ class StartTriggerArgs: # type: ignore[no-redef] timeout: datetime.timedelta | None = None -from airflow.triggers.temporal import DateTimeTrigger from airflow.utils import timezone if TYPE_CHECKING: @@ -111,7 +111,7 @@ class DateTimeSensorAsync(DateTimeSensor): """ start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.DateTimeTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.DateTimeTrigger", trigger_kwargs={"moment": "", "end_from_trigger": False}, next_method="execute_complete", next_kwargs=None, diff --git a/providers/src/airflow/providers/standard/sensors/filesystem.py b/providers/src/airflow/providers/standard/sensors/filesystem.py index 0a1c2d46bebff..6bf7452e59666 100644 --- a/providers/src/airflow/providers/standard/sensors/filesystem.py +++ b/providers/src/airflow/providers/standard/sensors/filesystem.py @@ -27,9 +27,9 @@ from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.standard.hooks.filesystem import FSHook +from airflow.providers.standard.triggers.file import FileTrigger from airflow.sensors.base import BaseSensorOperator from airflow.triggers.base import StartTriggerArgs -from airflow.triggers.file import FileTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -64,7 +64,7 @@ class FileSensor(BaseSensorOperator): template_fields: Sequence[str] = ("filepath",) ui_color = "#91818a" start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.file.FileTrigger", + trigger_cls="airflow.providers.standard.triggers.file.FileTrigger", trigger_kwargs={}, next_method="execute_complete", next_kwargs=None, diff --git a/providers/src/airflow/providers/standard/sensors/time.py b/providers/src/airflow/providers/standard/sensors/time.py index 8b727cb1cf1dd..6443f2a344a16 100644 --- a/providers/src/airflow/providers/standard/sensors/time.py +++ b/providers/src/airflow/providers/standard/sensors/time.py @@ -21,6 +21,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, NoReturn +from airflow.providers.standard.triggers.temporal import DateTimeTrigger from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS from airflow.sensors.base import BaseSensorOperator @@ -39,7 +40,6 @@ class StartTriggerArgs: # type: ignore[no-redef] timeout: datetime.timedelta | None = None -from airflow.triggers.temporal import DateTimeTrigger from airflow.utils import timezone if TYPE_CHECKING: @@ -85,7 +85,7 @@ class TimeSensorAsync(BaseSensorOperator): """ start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.DateTimeTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.DateTimeTrigger", trigger_kwargs={"moment": "", "end_from_trigger": False}, next_method="execute_complete", next_kwargs=None, diff --git 
a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index 0b50c5cef8630..a0d3189b027fd 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -23,9 +23,9 @@ from airflow.configuration import conf from airflow.exceptions import AirflowSkipException +from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS from airflow.sensors.base import BaseSensorOperator -from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone if TYPE_CHECKING: diff --git a/tests/triggers/__init__.py b/providers/src/airflow/providers/standard/triggers/__init__.py similarity index 100% rename from tests/triggers/__init__.py rename to providers/src/airflow/providers/standard/triggers/__init__.py diff --git a/airflow/triggers/external_task.py b/providers/src/airflow/providers/standard/triggers/external_task.py similarity index 80% rename from airflow/triggers/external_task.py rename to providers/src/airflow/providers/standard/triggers/external_task.py index 159a6df909501..ff99caf668d43 100644 --- a/airflow/triggers/external_task.py +++ b/providers/src/airflow/providers/standard/triggers/external_task.py @@ -24,8 +24,9 @@ from sqlalchemy import func from airflow.models import DagRun +from airflow.providers.standard.utils.sensor_helper import _get_count +from airflow.providers.standard.utils.version_references import AIRFLOW_V_3_0_PLUS from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.sensor_helper import _get_count from airflow.utils.session import NEW_SESSION, provide_session if typing.TYPE_CHECKING: @@ -54,7 +55,8 @@ class WorkflowTrigger(BaseTrigger): def __init__( self, external_dag_id: str, - logical_dates: list, + logical_dates: list[datetime] | None = None, + execution_dates: list[datetime] | None = None, external_task_ids: typing.Collection[str] | None = None, external_task_group_id: str | None = None, failed_states: typing.Iterable[str] | None = None, @@ -73,12 +75,18 @@ def __init__( self.logical_dates = logical_dates self.poke_interval = poke_interval self.soft_fail = soft_fail + self.execution_dates = execution_dates super().__init__(**kwargs) def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize the trigger param and module path.""" + _dates = ( + {"logical_dates": self.logical_dates} + if AIRFLOW_V_3_0_PLUS + else {"execution_dates": self.execution_dates} + ) return ( - "airflow.triggers.external_task.WorkflowTrigger", + "airflow.providers.standard.triggers.external_task.WorkflowTrigger", { "external_dag_id": self.external_dag_id, "external_task_ids": self.external_task_ids, @@ -86,7 +94,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "failed_states": self.failed_states, "skipped_states": self.skipped_states, "allowed_states": self.allowed_states, - "logical_dates": self.logical_dates, + **_dates, "poke_interval": self.poke_interval, "soft_fail": self.soft_fail, }, @@ -109,7 +117,8 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: yield TriggerEvent({"status": "skipped"}) return allowed_count = await self._get_count(self.allowed_states) - if allowed_count == len(self.logical_dates): + _dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates + if allowed_count == len(_dates): # type: ignore[arg-type] 
yield TriggerEvent({"status": "success"}) return self.log.info("Sleeping for %s seconds", self.poke_interval) @@ -124,7 +133,7 @@ def _get_count(self, states: typing.Iterable[str] | None) -> int: :return The count of records. """ return _get_count( - dttm_filter=self.logical_dates, + dttm_filter=self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates, external_task_ids=self.external_task_ids, external_task_group_id=self.external_task_group_id, external_dag_id=self.external_dag_id, @@ -147,23 +156,30 @@ def __init__( self, dag_id: str, states: list[DagRunState], - logical_dates: list[datetime], + logical_dates: list[datetime] | None = None, + execution_dates: list[datetime] | None = None, poll_interval: float = 5.0, ): super().__init__() self.dag_id = dag_id self.states = states self.logical_dates = logical_dates + self.execution_dates = execution_dates self.poll_interval = poll_interval def serialize(self) -> tuple[str, dict[str, typing.Any]]: """Serialize DagStateTrigger arguments and classpath.""" + _dates = ( + {"logical_dates": self.logical_dates} + if AIRFLOW_V_3_0_PLUS + else {"execution_dates": self.execution_dates} + ) return ( - "airflow.triggers.external_task.DagStateTrigger", + "airflow.providers.standard.triggers.external_task.DagStateTrigger", { "dag_id": self.dag_id, "states": self.states, - "logical_dates": self.logical_dates, + **_dates, "poll_interval": self.poll_interval, }, ) @@ -173,7 +189,8 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: while True: # mypy confuses typing here num_dags = await self.count_dags() # type: ignore[call-arg] - if num_dags == len(self.logical_dates): + _dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates + if num_dags == len(_dates): # type: ignore[arg-type] yield TriggerEvent(self.serialize()) return await asyncio.sleep(self.poll_interval) @@ -182,12 +199,17 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]: @provide_session def count_dags(self, *, session: Session = NEW_SESSION) -> int | None: """Count how many dag runs in the database match our criteria.""" + _dag_run_date_condition = ( + DagRun.logical_date.in_(self.logical_dates) + if AIRFLOW_V_3_0_PLUS + else DagRun.execution_date.in_(self.execution_dates) + ) count = ( session.query(func.count("*")) # .count() is inefficient .filter( DagRun.dag_id == self.dag_id, DagRun.state.in_(self.states), - DagRun.logical_date.in_(self.logical_dates), + _dag_run_date_condition, ) .scalar() ) diff --git a/airflow/triggers/file.py b/providers/src/airflow/providers/standard/triggers/file.py similarity index 85% rename from airflow/triggers/file.py rename to providers/src/airflow/providers/standard/triggers/file.py index 5f40dd4d2de3f..f6a7715a035eb 100644 --- a/airflow/triggers/file.py +++ b/providers/src/airflow/providers/standard/triggers/file.py @@ -20,7 +20,6 @@ import datetime import os import typing -import warnings from glob import glob from typing import Any @@ -48,21 +47,12 @@ def __init__( super().__init__() self.filepath = filepath self.recursive = recursive - if kwargs.get("poll_interval") is not None: - warnings.warn( - "`poll_interval` has been deprecated and will be removed in future." 
- "Please use `poke_interval` instead.", - DeprecationWarning, - stacklevel=2, - ) - self.poke_interval: float = kwargs["poll_interval"] - else: - self.poke_interval = poke_interval + self.poke_interval = poke_interval def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize FileTrigger arguments and classpath.""" return ( - "airflow.triggers.file.FileTrigger", + "airflow.providers.standard.triggers.file.FileTrigger", { "filepath": self.filepath, "recursive": self.recursive, diff --git a/airflow/triggers/temporal.py b/providers/src/airflow/providers/standard/triggers/temporal.py similarity index 89% rename from airflow/triggers/temporal.py rename to providers/src/airflow/providers/standard/triggers/temporal.py index e5a3ca1be00df..032af6186d2b3 100644 --- a/airflow/triggers/temporal.py +++ b/providers/src/airflow/providers/standard/triggers/temporal.py @@ -23,9 +23,14 @@ import pendulum -from airflow.triggers.base import BaseTrigger, TaskSuccessEvent, TriggerEvent +from airflow.exceptions import AirflowException +from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS +from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone +if AIRFLOW_V_2_10_PLUS: + from airflow.triggers.base import TaskSuccessEvent + class DateTimeTrigger(BaseTrigger): """ @@ -50,11 +55,14 @@ def __init__(self, moment: datetime.datetime, *, end_from_trigger: bool = False) raise ValueError("You cannot pass naive datetimes") else: self.moment: pendulum.DateTime = timezone.convert_to_utc(moment) + if not AIRFLOW_V_2_10_PLUS and end_from_trigger: + raise AirflowException("end_from_trigger is only supported in Airflow 2.10 and later. ") + self.end_from_trigger = end_from_trigger def serialize(self) -> tuple[str, dict[str, Any]]: return ( - "airflow.triggers.temporal.DateTimeTrigger", + "airflow.providers.standard.triggers.temporal.DateTimeTrigger", {"moment": self.moment, "end_from_trigger": self.end_from_trigger}, ) diff --git a/airflow/utils/sensor_helper.py b/providers/src/airflow/providers/standard/utils/sensor_helper.py similarity index 100% rename from airflow/utils/sensor_helper.py rename to providers/src/airflow/providers/standard/utils/sensor_helper.py diff --git a/providers/tests/standard/sensors/test_time.py b/providers/tests/standard/sensors/test_time.py index a144c3dc41de7..017b410eda4df 100644 --- a/providers/tests/standard/sensors/test_time.py +++ b/providers/tests/standard/sensors/test_time.py @@ -26,7 +26,7 @@ from airflow.exceptions import TaskDeferred from airflow.models.dag import DAG from airflow.providers.standard.sensors.time import TimeSensor, TimeSensorAsync -from airflow.triggers.temporal import DateTimeTrigger +from airflow.providers.standard.triggers.temporal import DateTimeTrigger from airflow.utils import timezone DEFAULT_TIMEZONE = "Asia/Singapore" # UTC+08:00 diff --git a/providers/tests/standard/triggers/__init__.py b/providers/tests/standard/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/standard/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/triggers/test_external_task.py b/providers/tests/standard/triggers/test_external_task.py similarity index 86% rename from tests/triggers/test_external_task.py rename to providers/tests/standard/triggers/test_external_task.py index 4a193c5fef9c5..debba6b310bdd 100644 --- a/tests/triggers/test_external_task.py +++ b/providers/tests/standard/triggers/test_external_task.py @@ -24,12 +24,21 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun +from airflow.providers.standard.triggers.external_task import DagStateTrigger, WorkflowTrigger from airflow.triggers.base import TriggerEvent -from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger from airflow.utils import timezone from airflow.utils.state import DagRunState +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS, AIRFLOW_V_3_0_PLUS +_DATES = ( + {"logical_dates": [timezone.datetime(2022, 1, 1)]} + if AIRFLOW_V_3_0_PLUS + else {"execution_dates": [timezone.datetime(2022, 1, 1)]} +) + + +@pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Test requires Airflow 2.9+") class TestWorkflowTrigger: DAG_ID = "external_task" TASK_ID = "external_task_op" @@ -37,14 +46,15 @@ class TestWorkflowTrigger: STATES = ["success", "fail"] @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.triggers.external_task._get_count") + @mock.patch("airflow.providers.standard.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_success(self, mock_get_count): """check the db count get called correctly.""" mock_get_count.side_effect = mocked_get_count + trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], allowed_states=self.STATES, poke_interval=0.2, @@ -70,13 +80,14 @@ async def test_task_workflow_trigger_success(self, mock_get_count): await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.triggers.external_task._get_count") + @mock.patch("airflow.providers.standard.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_failed(self, mock_get_count): mock_get_count.side_effect = mocked_get_count + trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], failed_states=self.STATES, poke_interval=0.2, @@ -102,13 +113,14 @@ async def test_task_workflow_trigger_failed(self, mock_get_count): with pytest.raises(StopAsyncIteration): await gen.__anext__() - @mock.patch("airflow.triggers.external_task._get_count") + @mock.patch("airflow.providers.standard.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): mock_get_count.return_value = 0 + trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], failed_states=self.STATES, poke_interval=0.2, @@ -133,13 +145,14 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count): 
await gen.__anext__() @pytest.mark.flaky(reruns=5) - @mock.patch("airflow.triggers.external_task._get_count") + @mock.patch("airflow.providers.standard.triggers.external_task._get_count") @pytest.mark.asyncio async def test_task_workflow_trigger_skipped(self, mock_get_count): mock_get_count.side_effect = mocked_get_count + trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], skipped_states=self.STATES, poke_interval=0.2, @@ -162,14 +175,15 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count): states=["success", "fail"], ) - @mock.patch("airflow.triggers.external_task._get_count") + @mock.patch("airflow.providers.standard.triggers.external_task._get_count") @mock.patch("asyncio.sleep") @pytest.mark.asyncio async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count): mock_get_count.side_effect = [0, 1] + trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], poke_interval=0.2, ) @@ -197,16 +211,16 @@ def test_serialization(self): """ trigger = WorkflowTrigger( external_dag_id=self.DAG_ID, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, external_task_ids=[self.TASK_ID], allowed_states=self.STATES, poke_interval=5, ) classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.external_task.WorkflowTrigger" + assert classpath == "airflow.providers.standard.triggers.external_task.WorkflowTrigger" assert kwargs == { "external_dag_id": self.DAG_ID, - "logical_dates": [timezone.datetime(2022, 1, 1)], + **_DATES, "external_task_ids": [self.TASK_ID], "external_task_group_id": None, "failed_states": None, @@ -231,10 +245,15 @@ async def test_dag_state_trigger(self, session): reaches an allowed state (i.e. SUCCESS). 
""" dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1)) + logical_date_or_execution_date = ( + {"logical_date": timezone.datetime(2022, 1, 1)} + if AIRFLOW_V_3_0_PLUS + else {"execution_date": timezone.datetime(2022, 1, 1)} + ) dag_run = DagRun( dag_id=dag.dag_id, run_type="manual", - logical_date=timezone.datetime(2022, 1, 1), + **logical_date_or_execution_date, run_id=self.RUN_ID, ) session.add(dag_run) @@ -243,7 +262,7 @@ async def test_dag_state_trigger(self, session): trigger = DagStateTrigger( dag_id=dag.dag_id, states=self.STATES, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, poll_interval=0.2, ) @@ -267,15 +286,15 @@ def test_serialization(self): trigger = DagStateTrigger( dag_id=self.DAG_ID, states=self.STATES, - logical_dates=[timezone.datetime(2022, 1, 1)], + **_DATES, poll_interval=5, ) classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.external_task.DagStateTrigger" + assert classpath == "airflow.providers.standard.triggers.external_task.DagStateTrigger" assert kwargs == { "dag_id": self.DAG_ID, "states": self.STATES, - "logical_dates": [timezone.datetime(2022, 1, 1)], + **_DATES, "poll_interval": 5, } diff --git a/tests/triggers/test_file.py b/providers/tests/standard/triggers/test_file.py similarity index 91% rename from tests/triggers/test_file.py rename to providers/tests/standard/triggers/test_file.py index 6fb25dea3f00c..baf0dffa80d0a 100644 --- a/tests/triggers/test_file.py +++ b/providers/tests/standard/triggers/test_file.py @@ -20,7 +20,7 @@ import pytest -from airflow.triggers.file import FileTrigger +from airflow.providers.standard.triggers.file import FileTrigger class TestFileTrigger: @@ -30,7 +30,7 @@ def test_serialization(self): """Asserts that the trigger correctly serializes its arguments and classpath.""" trigger = FileTrigger(filepath=self.FILE_PATH, poll_interval=5) classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.file.FileTrigger" + assert classpath == "airflow.providers.standard.triggers.file.FileTrigger" assert kwargs == { "filepath": self.FILE_PATH, "poke_interval": 5, @@ -46,7 +46,7 @@ async def test_task_file_trigger(self, tmp_path): trigger = FileTrigger( filepath=str(p.resolve()), - poll_interval=0.2, + poke_interval=0.2, ) task = asyncio.create_task(trigger.run().__anext__()) diff --git a/tests/triggers/test_temporal.py b/providers/tests/standard/triggers/test_temporal.py similarity index 71% rename from tests/triggers/test_temporal.py rename to providers/tests/standard/triggers/test_temporal.py index 90f00a694e5b3..7271d43f5d086 100644 --- a/tests/triggers/test_temporal.py +++ b/providers/tests/standard/triggers/test_temporal.py @@ -23,8 +23,9 @@ import pendulum import pytest +from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger +from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS from airflow.triggers.base import TriggerEvent -from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.state import TaskInstanceState from airflow.utils.timezone import utcnow @@ -56,7 +57,7 @@ def test_datetime_trigger_serialization(): moment = pendulum.instance(datetime.datetime(2020, 4, 1, 13, 0), pendulum.UTC) trigger = DateTimeTrigger(moment) classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.temporal.DateTimeTrigger" + assert classpath == 
"airflow.providers.standard.triggers.temporal.DateTimeTrigger" assert kwargs == {"moment": moment, "end_from_trigger": False} @@ -68,12 +69,13 @@ def test_timedelta_trigger_serialization(): trigger = TimeDeltaTrigger(datetime.timedelta(seconds=10)) expected_moment = timezone.utcnow() + datetime.timedelta(seconds=10) classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.temporal.DateTimeTrigger" + assert classpath == "airflow.providers.standard.triggers.temporal.DateTimeTrigger" # We need to allow for a little time difference to avoid this test being # flaky if it runs over the boundary of a single second assert -2 < (kwargs["moment"] - expected_moment).total_seconds() < 2 +@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Only for Airflow 2.10+") @pytest.mark.parametrize( "tz, end_from_trigger", [ @@ -84,7 +86,7 @@ def test_timedelta_trigger_serialization(): ], ) @pytest.mark.asyncio -async def test_datetime_trigger_timing(tz, end_from_trigger): +async def test_datetime_trigger_timing_airflow_2_10_plus(tz, end_from_trigger): """ Tests that the DateTimeTrigger only goes off on or after the appropriate time. @@ -113,8 +115,46 @@ async def test_datetime_trigger_timing(tz, end_from_trigger): assert result.payload == expected_payload -@mock.patch("airflow.triggers.temporal.timezone.utcnow") -@mock.patch("airflow.triggers.temporal.asyncio.sleep") +@pytest.mark.skipif(AIRFLOW_V_2_10_PLUS, reason="Only for Airflow < 2.10+") +@pytest.mark.parametrize( + "tz", + [ + timezone.parse_timezone("UTC"), + timezone.parse_timezone("Europe/Paris"), + timezone.parse_timezone("America/Toronto"), + ], +) +@pytest.mark.asyncio +async def test_datetime_trigger_timing(tz): + """ + Tests that the DateTimeTrigger only goes off on or after the appropriate + time. 
+ """ + past_moment = pendulum.instance((timezone.utcnow() - datetime.timedelta(seconds=60)).astimezone(tz)) + future_moment = pendulum.instance((timezone.utcnow() + datetime.timedelta(seconds=60)).astimezone(tz)) + + # Create a task that runs the trigger for a short time then cancels it + trigger = DateTimeTrigger(future_moment) + trigger_task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # It should not have produced a result + assert trigger_task.done() is False + trigger_task.cancel() + + # Now, make one waiting for en event in the past and do it again + trigger = DateTimeTrigger(past_moment) + trigger_task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + assert trigger_task.done() is True + result = trigger_task.result() + assert isinstance(result, TriggerEvent) + assert result.payload == past_moment + + +@mock.patch("airflow.providers.standard.triggers.temporal.timezone.utcnow") +@mock.patch("airflow.providers.standard.triggers.temporal.asyncio.sleep") @pytest.mark.asyncio async def test_datetime_trigger_mocked(mock_sleep, mock_utcnow): """ diff --git a/scripts/cov/other_coverage.py b/scripts/cov/other_coverage.py index dae7733ec5c15..41c7a4352369d 100644 --- a/scripts/cov/other_coverage.py +++ b/scripts/cov/other_coverage.py @@ -66,7 +66,6 @@ other_tests = [ "tests/dag_processing", "tests/jobs", - "tests/triggers", ] """ @@ -96,7 +95,6 @@ "tests/template", "tests/testconfig", "tests/timetables", - "tests/triggers", """ diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 7b9dd1e63bd18..aaff3f8730e9d 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -40,8 +40,8 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dag import _run_inline_trigger from airflow.models.serialized_dag import SerializedDagModel +from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.triggers.base import TriggerEvent -from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 98e63bcd40133..25ef1507fbdc5 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -419,9 +419,9 @@ def test_cli_test_with_env_vars(self): assert "foo=bar" in output assert "AIRFLOW_TEST_MODE=True" in output - @mock.patch("airflow.triggers.file.os.path.getmtime", return_value=0) - @mock.patch("airflow.triggers.file.glob", return_value=["/tmp/test"]) - @mock.patch("airflow.triggers.file.os.path.isfile", return_value=True) + @mock.patch("airflow.providers.standard.triggers.file.os.path.getmtime", return_value=0) + @mock.patch("airflow.providers.standard.triggers.file.glob", return_value=["/tmp/test"]) + @mock.patch("airflow.providers.standard.triggers.file.os.path.isfile", return_value=True) @mock.patch("airflow.providers.standard.sensors.filesystem.FileSensor.poke", return_value=False) def test_cli_test_with_deferrable_operator( self, mock_pock, mock_is_file, mock_glob, mock_getmtime, caplog diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py index bae777155bbcb..d9c47aa70edce 100644 --- a/tests/jobs/test_triggerer_job.py +++ b/tests/jobs/test_triggerer_job.py @@ -37,8 +37,8 @@ from 
airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator +from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.triggers.base import TriggerEvent -from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.triggers.testing import FailureTrigger, SuccessTrigger from airflow.utils import timezone from airflow.utils.log.logging_mixin import RedirectStdHandler diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 8a7f274499cd3..9f4eef47a1d0e 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -31,8 +31,8 @@ from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.providers.standard.triggers.external_task import DagStateTrigger from airflow.settings import TracebackSessionForTests -from airflow.triggers.external_task import DagStateTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index e03ceeed01960..e47e88d62ef70 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -39,12 +39,12 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.time import TimeSensor +from airflow.providers.standard.triggers.external_task import WorkflowTrigger from airflow.sensors.external_task import ( ExternalTaskMarker, ExternalTaskSensor, ) from airflow.serialization.serialized_objects import SerializedBaseOperator -from airflow.triggers.external_task import WorkflowTrigger from airflow.utils.hashlib_wrapper import md5 from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py index 0774b56c31cd8..22432e90a92f9 100644 --- a/tests/sensors/test_filesystem.py +++ b/tests/sensors/test_filesystem.py @@ -27,7 +27,7 @@ from airflow.exceptions import AirflowSensorTimeout, TaskDeferred from airflow.models.dag import DAG from airflow.providers.standard.sensors.filesystem import FileSensor -from airflow.triggers.file import FileTrigger +from airflow.providers.standard.triggers.file import FileTrigger from airflow.utils.timezone import datetime pytestmark = pytest.mark.db_test diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 69bd5259d7e51..555124f65b049 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -2204,7 +2204,7 @@ def test_start_trigger_args_in_serialized_dag(self): class TestOperator(BaseOperator): start_trigger_args = StartTriggerArgs( - trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger", + trigger_cls="airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", trigger_kwargs={"delta": timedelta(seconds=1)}, next_method="execute_complete", next_kwargs=None, @@ -2247,7 +2247,7 @@ def execute_complete(self): assert 
tasks[0]["__var"]["start_trigger_args"] == { "__type": "START_TRIGGER_ARGS", - "trigger_cls": "airflow.triggers.temporal.TimeDeltaTrigger", + "trigger_cls": "airflow.providers.standard.triggers.temporal.TimeDeltaTrigger", # "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, "trigger_kwargs": {"__type": "dict", "__var": {"delta": {"__type": "timedelta", "__var": 2.0}}}, "next_method": "execute_complete",