From e9f544cc3fb1ac3d7709b3c54804dd6fdd510eca Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Thu, 28 Nov 2024 17:29:13 +0000
Subject: [PATCH] Remove AIP-44 configuration from the code (#44454)

This change removes all configuration that controls AIP-44 behaviour.
It does not yet remove all of the related code (that will be a
follow-up), but it removes all the controls that determine whether
AIP-44 is enabled, and it removes the TracebackSession and
DB-session-disabling machinery that was used in "database isolation"
mode.

The "database isolation" mode was disabled in #44441, so there was no
easy way to enable it anyway - this change removes the capability to
use database isolation mode completely.

Part of #44436
---
 airflow/__main__.py | 30 +-
 airflow/api_internal/internal_api_call.py | 150 +-------
 airflow/cli/cli_config.py | 50 ++-
 airflow/cli/commands/dag_processor_command.py | 4 -
 airflow/cli/commands/db_command.py | 3 -
 airflow/cli/commands/internal_api_command.py | 8 -
 airflow/cli/commands/task_command.py | 13 +-
 airflow/executors/executor_loader.py | 4 -
 airflow/models/taskinstance.py | 11 +-
 airflow/sensors/base.py | 4 +-
 airflow/serialization/serialized_objects.py | 17 +-
 airflow/settings.py | 162 ---------
 airflow/task/standard_task_runner.py | 13 -
 airflow/utils/cli.py | 3 +-
 airflow/utils/task_instance_session.py | 8 -
 airflow/www/app.py | 7 -
 .../src/airflow_breeze/utils/run_tests.py | 2 +-
 .../airflow_breeze/utils/selective_checks.py | 1 -
 .../tests/test_pytest_args_for_test_types.py | 5 +-
 .../celery/executors/celery_executor_utils.py | 5 -
 .../providers/edge/cli/edge_command.py | 25 --
 .../providers/standard/operators/python.py | 5 -
 .../standard/operators/trigger_dagrun.py | 9 -
 .../edge/worker_api/routes/test_rpc_api.py | 298 ----------------
 .../tests/standard/operators/test_python.py | 25 --
 scripts/ci/docker-compose/devcontainer.env | 1 -
 scripts/cov/other_coverage.py | 1 -
 scripts/cov/restapi_coverage.py | 2 +-
 tests/api_internal/__init__.py | 16 -
 tests/api_internal/endpoints/__init__.py | 16 -
 .../endpoints/test_rpc_api_endpoint.py | 250 -------------
 tests/api_internal/test_internal_api_call.py | 327 ------------------
 .../cli/commands/test_internal_api_command.py | 223 ------------
 tests/core/test_settings.py | 22 +-
 tests/core/test_sqlalchemy_config.py | 5 -
 tests/models/test_taskinstance.py | 3 -
 tests/operators/test_trigger_dagrun.py | 3 -
 tests/serialization/test_pydantic_models.py | 295 ----------------
 .../serialization/test_serialized_objects.py | 4 -
 tests_common/pytest_plugin.py | 1 -
 tests_common/test_utils/compat.py | 5 -
 41 files changed, 41 insertions(+), 1995 deletions(-)
 delete mode 100644 providers/tests/edge/worker_api/routes/test_rpc_api.py
 delete mode 100644 tests/api_internal/__init__.py
 delete mode 100644 tests/api_internal/endpoints/__init__.py
 delete mode 100644 tests/api_internal/endpoints/test_rpc_api_endpoint.py
 delete mode 100644 tests/api_internal/test_internal_api_call.py
 delete mode 100644 tests/cli/commands/test_internal_api_command.py
 delete mode 100644 tests/serialization/test_pydantic_models.py

diff --git a/airflow/__main__.py b/airflow/__main__.py
index 8fbcd7e777640..bfebc63946ef6 100644
--- a/airflow/__main__.py
+++ b/airflow/__main__.py
@@ -22,7 +22,6 @@
 from __future__ import annotations
 
 import os
-from argparse import Namespace
 
 import argcomplete
 
@@ -36,8 +35,7 @@
 # any possible import cycles with settings downstream.
from airflow import configuration from airflow.cli import cli_parser -from airflow.configuration import AirflowConfigParser, write_webserver_configuration_if_needed -from airflow.exceptions import AirflowException +from airflow.configuration import write_webserver_configuration_if_needed def main(): @@ -57,34 +55,8 @@ def main(): conf = write_default_airflow_configuration_if_needed() if args.subcommand in ["webserver", "internal-api", "worker"]: write_webserver_configuration_if_needed(conf) - configure_internal_api(args, conf) - args.func(args) -def configure_internal_api(args: Namespace, conf: AirflowConfigParser): - if conf.getboolean("core", "database_access_isolation", fallback=False): - if args.subcommand in ["worker", "dag-processor", "triggerer", "run"]: - # Untrusted components - if "AIRFLOW__DATABASE__SQL_ALCHEMY_CONN" in os.environ: - # make sure that the DB is not available for the components that should not access it - os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = "none://" - conf.set("database", "sql_alchemy_conn", "none://") - from airflow.api_internal.internal_api_call import InternalApiConfig - - InternalApiConfig.set_use_internal_api(args.subcommand) - else: - # Trusted components (this setting is mostly for Breeze where db_isolation and DB are both set - db_connection_url = conf.get("database", "sql_alchemy_conn") - if not db_connection_url or db_connection_url == "none://": - raise AirflowException( - f"Running trusted components {args.subcommand} in db isolation mode " - f"requires connection to be configured via database/sql_alchemy_conn." - ) - from airflow.api_internal.internal_api_call import InternalApiConfig - - InternalApiConfig.set_use_database_access(args.subcommand) - - if __name__ == "__main__": main() diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index 064834d7c8673..4c0b613b78b4c 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -17,23 +17,13 @@ from __future__ import annotations -import inspect -import json import logging from functools import wraps from http import HTTPStatus from typing import Callable, TypeVar -from urllib.parse import urlparse -import requests -import tenacity -from urllib3.exceptions import NewConnectionError - -from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException, AirflowException -from airflow.settings import _ENABLE_AIP_44, force_traceback_session_for_untrusted_components +from airflow.exceptions import AirflowException from airflow.typing_compat import ParamSpec -from airflow.utils.jwt_signer import JWTSigner PS = ParamSpec("PS") RT = TypeVar("RT") @@ -49,145 +39,9 @@ def __init__(self, message: str, status_code: HTTPStatus): self.status_code = status_code -class InternalApiConfig: - """Stores and caches configuration for Internal API.""" - - _use_internal_api = False - _internal_api_endpoint = "" - - @staticmethod - def set_use_database_access(component: str): - """ - Block current component from using Internal API. - - All methods decorated with internal_api_call will always be executed locally.` - This mode is needed for "trusted" components like Scheduler, Webserver, Internal Api server - """ - InternalApiConfig._use_internal_api = False - if not _ENABLE_AIP_44: - raise RuntimeError("The AIP_44 is not enabled so you cannot use it. ") - logger.info( - "DB isolation mode. But this is a trusted component and DB connection is set. 
" - "Using database direct access when running %s.", - component, - ) - - @staticmethod - def set_use_internal_api(component: str, allow_tests_to_use_db: bool = False): - if not _ENABLE_AIP_44: - raise RuntimeError("The AIP_44 is not enabled so you cannot use it. ") - internal_api_url = conf.get("core", "internal_api_url") - url_conf = urlparse(internal_api_url) - api_path = url_conf.path - if api_path in ["", "/"]: - # Add the default path if not given in the configuration - api_path = "/internal_api/v1/rpcapi" - if url_conf.scheme not in ["http", "https"]: - raise AirflowConfigException("[core]internal_api_url must start with http:// or https://") - internal_api_endpoint = f"{url_conf.scheme}://{url_conf.netloc}{api_path}" - InternalApiConfig._use_internal_api = True - InternalApiConfig._internal_api_endpoint = internal_api_endpoint - logger.info("DB isolation mode. Using internal_api when running %s.", component) - force_traceback_session_for_untrusted_components(allow_tests_to_use_db=allow_tests_to_use_db) - - @staticmethod - def get_use_internal_api(): - return InternalApiConfig._use_internal_api - - @staticmethod - def get_internal_api_endpoint(): - return InternalApiConfig._internal_api_endpoint - - def internal_api_call(func: Callable[PS, RT]) -> Callable[PS, RT]: - """ - Allow methods to be executed in database isolation mode. - - If [core]database_access_isolation is true then such method are not executed locally, - but instead RPC call is made to Database API (aka Internal API). This makes some components - decouple from direct Airflow database access. - Each decorated method must be present in METHODS list in airflow.api_internal.endpoints.rpc_api_endpoint. - Only static methods can be decorated. This decorator must be before "provide_session". - - See [AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API) - for more information . - """ - from requests.exceptions import ConnectionError - - def _is_retryable_exception(exception: BaseException) -> bool: - """ - Evaluate which exception types to retry. - - This is especially demanded for cases where an application gateway or Kubernetes ingress can - not find a running instance of a webserver hosting the API (HTTP 502+504) or when the - HTTP request fails in general on network level. - - Note that we want to fail on other general errors on the webserver not to send bad requests in an endless loop. 
- """ - retryable_status_codes = (HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT) - return ( - isinstance(exception, AirflowHttpException) - and exception.status_code in retryable_status_codes - or isinstance(exception, (ConnectionError, NewConnectionError)) - ) - - @tenacity.retry( - stop=tenacity.stop_after_attempt(10), - wait=tenacity.wait_exponential(min=1), - retry=tenacity.retry_if_exception(_is_retryable_exception), - before_sleep=tenacity.before_log(logger, logging.WARNING), - ) - def make_jsonrpc_request(method_name: str, params_json: str) -> bytes: - signer = JWTSigner( - secret_key=conf.get("core", "internal_api_secret_key"), - expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), - audience="api", - ) - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": method_name}), - } - data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} - internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint() - response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers) - if response.status_code != 200: - raise AirflowHttpException( - f"Got {response.status_code}:{response.reason} when sending " - f"the internal api request: {response.text}", - HTTPStatus(response.status_code), - ) - return response.content - @wraps(func) def wrapper(*args, **kwargs): - use_internal_api = InternalApiConfig.get_use_internal_api() - if not use_internal_api: - return func(*args, **kwargs) - import traceback - - tb = traceback.extract_stack() - if any(filename.endswith("conftest.py") for filename, _, _, _ in tb): - # This is a test fixture, we should not use internal API for it - return func(*args, **kwargs) - - from airflow.serialization.serialized_objects import BaseSerialization # avoid circular import - - bound = inspect.signature(func).bind(*args, **kwargs) - arguments_dict = dict(bound.arguments) - if "session" in arguments_dict: - del arguments_dict["session"] - if "cls" in arguments_dict: # used by @classmethod - del arguments_dict["cls"] - - args_dict = BaseSerialization.serialize(arguments_dict, use_pydantic_models=True) - method_name = f"{func.__module__}.{func.__qualname__}" - result = make_jsonrpc_request(method_name, args_dict) - if result is None or result == b"": - return None - result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True) - if isinstance(result, (KeyError, AttributeError, AirflowException)): - raise result - return result + return func(*args, **kwargs) return wrapper diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index d03ebd312600e..21d09d0bca811 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -32,7 +32,6 @@ from airflow import settings from airflow.cli.commands.legacy_commands import check_legacy_command from airflow.configuration import conf -from airflow.settings import _ENABLE_AIP_44 from airflow.utils.cli import ColorMode from airflow.utils.module_loading import import_string from airflow.utils.state import DagRunState, JobState @@ -2071,32 +2070,31 @@ class GroupCommand(NamedTuple): ), ] -if _ENABLE_AIP_44: - core_commands.append( - ActionCommand( - name="internal-api", - help="Start an Airflow Internal API instance", - func=lazy_load_command("airflow.cli.commands.internal_api_command.internal_api"), - args=( - ARG_INTERNAL_API_PORT, - ARG_INTERNAL_API_WORKERS, - ARG_INTERNAL_API_WORKERCLASS, - ARG_INTERNAL_API_WORKER_TIMEOUT, 
- ARG_INTERNAL_API_HOSTNAME, - ARG_PID, - ARG_DAEMON, - ARG_STDOUT, - ARG_STDERR, - ARG_INTERNAL_API_ACCESS_LOGFILE, - ARG_INTERNAL_API_ERROR_LOGFILE, - ARG_INTERNAL_API_ACCESS_LOGFORMAT, - ARG_LOG_FILE, - ARG_SSL_CERT, - ARG_SSL_KEY, - ARG_DEBUG, - ), +core_commands.append( + ActionCommand( + name="internal-api", + help="Start an Airflow Internal API instance", + func=lazy_load_command("airflow.cli.commands.internal_api_command.internal_api"), + args=( + ARG_INTERNAL_API_PORT, + ARG_INTERNAL_API_WORKERS, + ARG_INTERNAL_API_WORKERCLASS, + ARG_INTERNAL_API_WORKER_TIMEOUT, + ARG_INTERNAL_API_HOSTNAME, + ARG_PID, + ARG_DAEMON, + ARG_STDOUT, + ARG_STDERR, + ARG_INTERNAL_API_ACCESS_LOGFILE, + ARG_INTERNAL_API_ERROR_LOGFILE, + ARG_INTERNAL_API_ACCESS_LOGFORMAT, + ARG_LOG_FILE, + ARG_SSL_CERT, + ARG_SSL_KEY, + ARG_DEBUG, ), - ) + ), +) def _remove_dag_id_opt(command: ActionCommand): diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py index eea1c0db20dc5..042733976f81f 100644 --- a/airflow/cli/commands/dag_processor_command.py +++ b/airflow/cli/commands/dag_processor_command.py @@ -22,7 +22,6 @@ from datetime import timedelta from typing import Any -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.configuration import conf from airflow.dag_processing.manager import DagFileProcessorManager, reload_configuration_for_dag_processing @@ -38,9 +37,6 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: """Create DagFileProcessorProcess instance.""" processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout") processor_timeout = timedelta(seconds=processor_timeout_seconds) - if InternalApiConfig.get_use_internal_api(): - from airflow.models.renderedtifields import RenderedTaskInstanceFields # noqa: F401 - from airflow.models.trigger import Trigger # noqa: F401 return DagProcessorJobRunner( job=Job(), processor=DagFileProcessorManager( diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py index cd7493212b4ee..ff268d1de1662 100644 --- a/airflow/cli/commands/db_command.py +++ b/airflow/cli/commands/db_command.py @@ -29,7 +29,6 @@ from tenacity import Retrying, stop_after_attempt, wait_fixed from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowException from airflow.utils import cli as cli_utils, db from airflow.utils.db import _REVISION_HEADS_MAP @@ -281,8 +280,6 @@ def shell(args): @providers_configuration_loaded def check(args): """Run a check command that checks if db is available.""" - if InternalApiConfig.get_use_internal_api(): - return retries: int = args.retry retry_delay: int = args.retry_delay diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py index d1ab8eea86787..45c930b47fc0f 100644 --- a/airflow/cli/commands/internal_api_command.py +++ b/airflow/cli/commands/internal_api_command.py @@ -38,7 +38,6 @@ from sqlalchemy.engine.url import make_url from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.cli.commands.webserver_command import GunicornMonitor from airflow.configuration import conf @@ -222,13 +221,6 @@ def create_app(config=None, testing=False): if "SQLALCHEMY_ENGINE_OPTIONS" not in 
flask_app.config: flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args() - if conf.getboolean("core", "database_access_isolation", fallback=False): - InternalApiConfig.set_use_database_access("Gunicorn worker initialization") - else: - raise AirflowConfigException( - "The internal-api component should only be run when database_access_isolation is enabled." - ) - csrf = CSRFProtect() csrf.init_app(flask_app) diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 3962592e9b2d5..4748aea2bbf68 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -35,7 +35,7 @@ from sqlalchemy import select from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.api_internal.internal_api_call import internal_api_call from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound @@ -333,9 +333,6 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) -> None: def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic) -> TaskReturnCode | None: """Run LocalTaskJob, which monitors the raw task execution process.""" - if InternalApiConfig.get_use_internal_api(): - from airflow.models.renderedtifields import RenderedTaskInstanceFields # noqa: F401 - from airflow.models.trigger import Trigger # noqa: F401 job_runner = LocalTaskJobRunner( job=Job(dag_id=ti.dag_id), task_instance=ti, @@ -490,14 +487,6 @@ def task_run(args, dag: DAG | None = None) -> TaskReturnCode | None: log.info("Running %s on host %s", ti, hostname) - if not InternalApiConfig.get_use_internal_api(): - # IMPORTANT, have to re-configure ORM with the NullPool, otherwise, each "run" command may leave - # behind multiple open sleeping connections while heartbeating, which could - # easily exceed the database connection limit when - # processing hundreds of simultaneous tasks. - # this should be last thing before running, to reduce likelihood of an open session - # which can cause trouble if running process in a fork. 
- settings.reconfigure_orm(disable_connection_pool=True) task_return_code = None try: if args.interactive: diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py index 84375a4baeb0b..8093566ab5abe 100644 --- a/airflow/executors/executor_loader.py +++ b/airflow/executors/executor_loader.py @@ -23,7 +23,6 @@ import os from typing import TYPE_CHECKING -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowConfigException, UnknownExecutorException from airflow.executors.executor_constants import ( CELERY_EXECUTOR, @@ -293,9 +292,6 @@ def validate_database_executor_compatibility(cls, executor: type[BaseExecutor]) if os.environ.get("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK") == "1": return - if InternalApiConfig.get_use_internal_api(): - return - from airflow.settings import engine # SQLite only works with single threaded executors diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 53e1925234ecb..591f3549bab39 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -28,7 +28,6 @@ import signal from collections import defaultdict from collections.abc import Collection, Generator, Iterable, Mapping -from contextlib import nullcontext from datetime import timedelta from enum import Enum from functools import cache @@ -73,7 +72,7 @@ from sqlalchemy_utils import UUIDType from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.api_internal.internal_api_call import internal_api_call from airflow.assets.manager import asset_manager from airflow.configuration import conf from airflow.exceptions import ( @@ -757,7 +756,7 @@ def _execute_callable(context: Context, **execute_callable_kwargs): raise else: result = _execute_callable(context=context, **execute_callable_kwargs) - cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session() + cm = create_session() with cm as session_or_null: if task_to_execute.do_xcom_push: xcom_value = result @@ -859,10 +858,6 @@ def _refresh_from_db( :meta private: """ - if not InternalApiConfig.get_use_internal_api(): - if session and task_instance in session: - session.refresh(task_instance, TaskInstance.__mapper__.column_attrs.keys()) - ti = TaskInstance.get_task_instance( dag_id=task_instance.dag_id, task_id=task_instance.task_id, @@ -2966,7 +2961,7 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) def update_heartbeat(self): - cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session() + cm = create_session() with cm as session_or_null: _update_ti_heartbeat(self.id, timezone.utcnow(), session_or_null) diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index f117b97d0ce5a..1c56aa42005a0 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -29,7 +29,7 @@ from sqlalchemy import select from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.exceptions import ( AirflowException, @@ -59,8 +59,6 @@ @functools.cache def _is_metadatabase_mysql() -> bool: - if InternalApiConfig.get_use_internal_api(): - return False if settings.engine is None: raise AirflowException("Must initialize ORM first") return 
settings.engine.url.get_backend_name() == "mysql" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index ced1bdd6837aa..1a13430e2fcb5 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -73,7 +73,7 @@ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.pydantic.tasklog import LogTemplatePydantic from airflow.serialization.pydantic.trigger import TriggerPydantic -from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json +from airflow.settings import DAGS_FOLDER, json from airflow.task.priority_strategy import ( PriorityWeightStrategy, airflow_priority_weight_strategies, @@ -633,11 +633,6 @@ def serialize( :meta private: """ - if use_pydantic_models and not _ENABLE_AIP_44: - raise RuntimeError( - "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. " - "This parameter will be removed eventually when new serialization is used by AIP-44" - ) if cls._is_primitive(var): # enum.IntEnum is an int instance, it causes json dumps error so we use its value. if isinstance(var, enum.Enum): @@ -764,7 +759,7 @@ def serialize( obj = cls.serialize(v, strict=strict, use_pydantic_models=use_pydantic_models) d[str(k)] = obj return cls._encode(d, type_=DAT.TASK_CONTEXT) - elif use_pydantic_models and _ENABLE_AIP_44: + elif use_pydantic_models: def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]: return model_cls.model_validate(var).model_dump(mode="json") # type: ignore[attr-defined] @@ -795,12 +790,6 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: :meta private: """ - # JSON primitives (except for dict) are not encoded. - if use_pydantic_models and not _ENABLE_AIP_44: - raise RuntimeError( - "Setting use_pydantic_models = True requires AIP-44 (in progress) feature flag to be true. " - "This parameter will be removed eventually when new serialization is used by AIP-44" - ) if cls._is_primitive(encoded_var): return encoded_var elif isinstance(encoded_var, list): @@ -892,7 +881,7 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return DagCallbackRequest.from_json(var) elif type_ == DAT.TASK_INSTANCE_KEY: return TaskInstanceKey(**var) - elif use_pydantic_models and _ENABLE_AIP_44: + elif use_pydantic_models: return _type_to_class[type_][0].model_validate(var) elif type_ == DAT.ARG_NOT_SET: return NOTSET diff --git a/airflow/settings.py b/airflow/settings.py index d4d6346f3f131..afde6d68d7df6 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -23,7 +23,6 @@ import logging import os import sys -import traceback import warnings from importlib import metadata from typing import TYPE_CHECKING, Any, Callable @@ -301,39 +300,6 @@ def get_bind( pass -def get_cleaned_traceback(stack_summary: traceback.StackSummary) -> str: - clened_traceback = [ - frame - for frame in stack_summary[:-2] - if "/_pytest" not in frame.filename and "/pluggy" not in frame.filename - ] - return "".join(traceback.format_list(clened_traceback)) - - -class TracebackSession: - """ - Session that throws error when you try to use it. - - Also stores stack at instantiation call site. - - :meta private: - """ - - def __init__(self): - self.traceback = traceback.extract_stack() - - def __getattr__(self, item): - raise RuntimeError( - "TracebackSession object was used but internal API is enabled. 
" - "You'll need to ensure you are making only RPC calls with this object. " - "The stack list below will show where the TracebackSession object was created." - + get_cleaned_traceback(self.traceback) - ) - - def remove(*args, **kwargs): - pass - - AIRFLOW_PATH = os.path.dirname(os.path.dirname(__file__)) AIRFLOW_TESTS_PATH = os.path.join(AIRFLOW_PATH, "tests") AIRFLOW_SETTINGS_PATH = os.path.join(AIRFLOW_PATH, "airflow", "settings.py") @@ -341,100 +307,6 @@ def remove(*args, **kwargs): AIRFLOW_MODELS_BASEOPERATOR_PATH = os.path.join(AIRFLOW_PATH, "airflow", "models", "baseoperator.py") -class TracebackSessionForTests: - """ - Session that throws error when you try to create a session outside of the test code. - - When we run our tests in "db isolation" mode we expect that "airflow" code will never create - a session on its own and internal_api server is used for all calls but the test code might use - the session to setup and teardown in the DB so that the internal API server accesses it. - - :meta private: - """ - - db_session_class = None - allow_db_access = False - """For pytests to create/prepare stuff where explicit DB access it needed""" - - def __init__(self): - self.current_db_session = TracebackSessionForTests.db_session_class() - self.created_traceback = traceback.extract_stack() - - def __getattr__(self, item): - test_code, frame_summary = self.is_called_from_test_code() - if self.allow_db_access or test_code: - return getattr(self.current_db_session, item) - raise RuntimeError( - "TracebackSessionForTests object was used but internal API is enabled. " - "Only test code is allowed to use this object.\n" - f"Called from:\n {frame_summary.filename}: {frame_summary.lineno}\n" - f" {frame_summary.line}\n\n" - "You'll need to ensure you are making only RPC calls with this object. " - "The stack list below will show where the TracebackSession object was called:\n" - + get_cleaned_traceback(self.traceback) - + "\n\nThe stack list below will show where the TracebackSession object was created:\n" - + get_cleaned_traceback(self.created_traceback) - ) - - def remove(*args, **kwargs): - pass - - @staticmethod - def set_allow_db_access(session, flag: bool): - """Temporarily, e.g. for pytests allow access to DB to prepare stuff.""" - if isinstance(session, TracebackSessionForTests): - session.allow_db_access = flag - - def is_called_from_test_code(self) -> tuple[bool, traceback.FrameSummary | None]: - """ - Check if the traceback session was used from the test code. - - This is done by checking if the first "airflow" filename in the traceback - is "airflow/tests" or "regular airflow". - - :meta: private - :return: True if the object was created from test code, False otherwise. 
- """ - self.traceback = traceback.extract_stack() - airflow_frames = [ - tb - for tb in self.traceback - if tb.filename.startswith(AIRFLOW_PATH) - and not tb.filename == AIRFLOW_SETTINGS_PATH - and not tb.filename == AIRFLOW_UTILS_SESSION_PATH - ] - if any( - filename.endswith("conftest.py") - or filename.endswith("dev/airflow_common_pytest/test_utils/db.py") - for filename, _, _, _ in airflow_frames - ): - # This is a fixture call or testing utilities - return True, None - if ( - len(airflow_frames) >= 2 - and airflow_frames[-2].filename.startswith(AIRFLOW_TESTS_PATH) - and airflow_frames[-1].filename == AIRFLOW_MODELS_BASEOPERATOR_PATH - and airflow_frames[-1].name == "run" - ): - # This is baseoperator run method that is called directly from the test code and this is - # usual pattern where we create a session in the test code to create dag_runs for tests. - # If `run` code will be run inside a real "airflow" code the stack trace would be longer - # and it would not be directly called from the test code. Also if subsequently any of the - # run_task() method called later from the task code will attempt to execute any DB - # method, the stack trace will be longer and we will catch it as "illegal" call. - return True, None - for tb in airflow_frames[::-1]: - if tb.filename.startswith(AIRFLOW_PATH): - if tb.filename.startswith(AIRFLOW_TESTS_PATH): - # this is a session created directly in the test code - return True, None - else: - return False, tb - # if it is from elsewhere.... Why???? We should return False in order to crash to find out - # The traceback line will be always 3rd (two bottom ones are Airflow) - return False, self.traceback[-2] - - def _is_sqlite_db_path_relative(sqla_conn_str: str) -> bool: """Determine whether the database connection URI specifies a relative path.""" # Check for non-empty connection string: @@ -477,11 +349,6 @@ def configure_orm(disable_connection_pool=False, pool_class=None): Session = SkipDBTestsSession engine = None return - if conf.get("database", "sql_alchemy_conn") == "none://": - from airflow.api_internal.internal_api_call import InternalApiConfig - - InternalApiConfig.set_use_internal_api("ORM reconfigured in forked process.") - return log.debug("Setting up DB connection pool (PID %s)", os.getpid()) engine_args = prepare_engine_args(disable_connection_pool, pool_class) @@ -525,25 +392,6 @@ def _session_maker(_engine): Session = scoped_session(NonScopedSession) -def force_traceback_session_for_untrusted_components(allow_tests_to_use_db=False): - log.info("Forcing TracebackSession for untrusted components.") - global Session - global engine - if allow_tests_to_use_db: - old_session_class = Session - Session = TracebackSessionForTests - TracebackSessionForTests.db_session_class = old_session_class - else: - try: - dispose_orm() - except NameError: - # This exception might be thrown in case the ORM has not been initialized yet. - pass - else: - Session = TracebackSession - engine = None - - DEFAULT_ENGINE_ARGS = { "postgresql": { "executemany_mode": "values_plus_batch", @@ -887,13 +735,3 @@ def is_usage_data_collection_enabled() -> bool: AIRFLOW_MOVED_TABLE_PREFIX = "_airflow_moved" DAEMON_UMASK: str = conf.get("core", "daemon_umask", fallback="0o077") - -# AIP-44: internal_api (experimental) -# This feature is not complete yet, so we disable it by default. 
-_ENABLE_AIP_44: bool = os.environ.get("AIRFLOW_ENABLE_AIP_44", "false").lower() in { - "true", - "t", - "yes", - "y", - "1", -} diff --git a/airflow/task/standard_task_runner.py b/airflow/task/standard_task_runner.py index c446bccac9606..b4d05fc475389 100644 --- a/airflow/task/standard_task_runner.py +++ b/airflow/task/standard_task_runner.py @@ -131,28 +131,15 @@ def _start_by_fork(self): self.log.info("Started process %d to run task", pid) return psutil.Process(pid) else: - from airflow.api_internal.internal_api_call import InternalApiConfig - from airflow.configuration import conf - - if conf.getboolean("core", "database_access_isolation", fallback=False): - InternalApiConfig.set_use_internal_api("Forked task runner") # Start a new process group set_new_process_group() signal.signal(signal.SIGINT, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL) - from airflow import settings from airflow.cli.cli_parser import get_parser from airflow.sentry import Sentry - if not InternalApiConfig.get_use_internal_api(): - # Force a new SQLAlchemy session. We can't share open DB handles - # between process. The cli code will re-create this as part of its - # normal startup - settings.engine.pool.dispose() - settings.engine.dispose() - parser = get_parser() # [1:] - remove "airflow" from the start of the command args = parser.parse_args(self._command[1:]) diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index c3cd1faf6b26d..c1aba4da8f580 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -34,7 +34,6 @@ import re2 from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowException from airflow.utils import cli_action_loggers, timezone from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler @@ -102,7 +101,7 @@ def wrapper(*args, **kwargs): handler.setLevel(logging.DEBUG) try: # Check and run migrations if necessary - if check_db and not InternalApiConfig.get_use_internal_api(): + if check_db: from airflow.configuration import conf from airflow.utils.db import check_and_run_migrations, synchronize_log_template diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py index 6234463f0d879..bb9741bf52566 100644 --- a/airflow/utils/task_instance_session.py +++ b/airflow/utils/task_instance_session.py @@ -23,8 +23,6 @@ from typing import TYPE_CHECKING from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig -from airflow.settings import TracebackSession if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -37,9 +35,6 @@ def get_current_task_instance_session() -> Session: global __current_task_instance_session if not __current_task_instance_session: - if InternalApiConfig.get_use_internal_api(): - __current_task_instance_session = TracebackSession() - return __current_task_instance_session log.warning("No task session set for this task. 
Continuing but this likely causes a resource leak.") log.warning("Please report this and stacktrace below to https://github.com/apache/airflow/issues") for filename, line_number, name, line in traceback.extract_stack(): @@ -52,9 +47,6 @@ def get_current_task_instance_session() -> Session: @contextlib.contextmanager def set_current_task_instance_session(session: Session): - if InternalApiConfig.get_use_internal_api(): - yield - return global __current_task_instance_session if __current_task_instance_session: raise RuntimeError( diff --git a/airflow/www/app.py b/airflow/www/app.py index 3409510b5a1a6..91b23875dcda1 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -27,12 +27,10 @@ from sqlalchemy.engine.url import make_url from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.configuration import conf from airflow.exceptions import AirflowConfigException from airflow.logging_config import configure_logging from airflow.models import import_all_models -from airflow.settings import _ENABLE_AIP_44 from airflow.utils.json import AirflowJsonProvider from airflow.www.extensions.init_appbuilder import init_appbuilder from airflow.www.extensions.init_appbuilder_links import init_appbuilder_links @@ -125,9 +123,6 @@ def create_app(config=None, testing=False): flask_app.json_provider_class = AirflowJsonProvider flask_app.json = AirflowJsonProvider(flask_app) - if conf.getboolean("core", "database_access_isolation", fallback=False): - InternalApiConfig.set_use_database_access("Gunicorn worker initialization") - csrf.init_app(flask_app) init_wsgi_middleware(flask_app) @@ -160,8 +155,6 @@ def create_app(config=None, testing=False): init_error_handlers(flask_app) init_api_connexion(flask_app) if conf.getboolean("webserver", "run_internal_api", fallback=False): - if not _ENABLE_AIP_44: - raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") init_api_internal(flask_app) init_api_auth_provider(flask_app) init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first diff --git a/dev/breeze/src/airflow_breeze/utils/run_tests.py b/dev/breeze/src/airflow_breeze/utils/run_tests.py index d92c5ad033e6f..d28194bbe0381 100644 --- a/dev/breeze/src/airflow_breeze/utils/run_tests.py +++ b/dev/breeze/src/airflow_breeze/utils/run_tests.py @@ -154,7 +154,7 @@ def get_excluded_provider_args(python_version: str) -> list[str]: TEST_TYPE_CORE_MAP_TO_PYTEST_ARGS: dict[str, list[str]] = { "Always": ["tests/always"], - "API": ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], + "API": ["tests/api", "tests/api_connexion", "tests/api_fastapi"], "CLI": ["tests/cli"], "Core": [ "tests/core", diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index 5f375e6dc7a33..53a53a5015c53 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -295,7 +295,6 @@ def __hash__(self): r"^airflow/api_fastapi/", r"^tests/api/", r"^tests/api_connexion/", - r"^tests/api_internal/", r"^tests/api_fastapi/", ], SelectiveCoreTestType.CLI: [ diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index 86a72a4de3838..be5138699722f 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -51,7 +51,7 @@ ( GroupOfTests.CORE, "API", 
- ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], + ["tests/api", "tests/api_connexion", "tests/api_fastapi"], ), ( GroupOfTests.CORE, @@ -172,7 +172,7 @@ def test_pytest_args_for_missing_provider(): ( GroupOfTests.CORE, "API", - ["tests/api", "tests/api_connexion", "tests/api_internal", "tests/api_fastapi"], + ["tests/api", "tests/api_connexion", "tests/api_fastapi"], ), ( GroupOfTests.CORE, @@ -187,7 +187,6 @@ def test_pytest_args_for_missing_provider(): [ "tests/api", "tests/api_connexion", - "tests/api_internal", "tests/api_fastapi", "tests/cli", ], diff --git a/providers/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/src/airflow/providers/celery/executors/celery_executor_utils.py index 12ae2c91cc131..65f8dfbe5b85c 100644 --- a/providers/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -41,7 +41,6 @@ from sqlalchemy import select import airflow.settings as settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout from airflow.executors.base_executor import BaseExecutor @@ -161,10 +160,6 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = try: from airflow.cli.cli_parser import get_parser - if not InternalApiConfig.get_use_internal_api(): - settings.engine.pool.dispose() - settings.engine.dispose() - parser = get_parser() # [1:] - remove "airflow" from the start of the command args = parser.parse_args(command_to_exec[1:]) diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index fbcf9323d0c41..3712049b20776 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -21,7 +21,6 @@ import os import platform import signal -import sys from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -32,7 +31,6 @@ from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile from airflow import __version__ as airflow_version, settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -59,29 +57,6 @@ ) -@providers_configuration_loaded -def force_use_internal_api_on_edge_worker(): - """ - Ensure that the environment is configured for the internal API without needing to declare it outside. - - This is only required for an Edge worker and must to be done before the Click CLI wrapper is initiated. - That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the - function call can take effect. In an Edge worker, we need to "patch" the environment before starting. 
- """ - if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]: - api_url = conf.get("edge", "api_url") - if not api_url: - raise SystemExit("Error: API URL is not configured, please correct configuration.") - logger.info("Starting worker with API endpoint %s", api_url) - # export Edge API to be used for internal API - os.environ["AIRFLOW_ENABLE_AIP_44"] = "True" - os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url - InternalApiConfig.set_use_internal_api("edge-worker") - - -force_use_internal_api_on_edge_worker() - - def _hostname() -> str: if IS_WINDOWS: return platform.uname().node diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index c1865f0132bf6..264cfc4e7cb7a 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -50,7 +50,6 @@ from airflow.operators.branch import BranchMixIn from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script from airflow.providers.standard.utils.version_references import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS -from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge @@ -524,10 +523,6 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): self._write_args(input_path) self._write_string_args(string_args_path) - if self.use_airflow_context and not _ENABLE_AIP_44: - error_msg = "`get_current_context()` needs to be used with AIP-44 enabled." - raise AirflowException(error_msg) - jinja_context = { "op_args": self.op_args, "op_kwargs": op_kwargs, diff --git a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py index fcc066778330b..e5f7aca313f43 100644 --- a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -27,7 +27,6 @@ from sqlalchemy.orm.exc import NoResultFound from airflow.api.common.trigger_dag import trigger_dag -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.configuration import conf from airflow.exceptions import ( AirflowException, @@ -182,14 +181,6 @@ def __init__( self.logical_date = logical_date def execute(self, context: Context): - if InternalApiConfig.get_use_internal_api(): - if self.reset_dag_run: - raise AirflowException("Parameter reset_dag_run=True is broken with Database Isolation Mode.") - if self.wait_for_completion: - raise AirflowException( - "Parameter wait_for_completion=True is broken with Database Isolation Mode." - ) - if isinstance(self.logical_date, datetime.datetime): parsed_logical_date = self.logical_date elif isinstance(self.logical_date, str): diff --git a/providers/tests/edge/worker_api/routes/test_rpc_api.py b/providers/tests/edge/worker_api/routes/test_rpc_api.py deleted file mode 100644 index a19ed59e8b0e2..0000000000000 --- a/providers/tests/edge/worker_api/routes/test_rpc_api.py +++ /dev/null @@ -1,298 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import json -from collections.abc import Generator -from typing import TYPE_CHECKING -from unittest import mock - -import pytest -from packaging.version import Version - -from airflow import __version__ as airflow_version -from airflow.configuration import conf -from airflow.models.baseoperator import BaseOperator -from airflow.models.connection import Connection -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import XCom -from airflow.operators.empty import EmptyOperator -from airflow.providers.edge.models.edge_job import EdgeJob -from airflow.providers.edge.models.edge_logs import EdgeLogs -from airflow.providers.edge.models.edge_worker import EdgeWorker -from airflow.providers.edge.worker_api.routes.rpc_api import _initialize_method_map -from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic -from airflow.serialization.serialized_objects import BaseSerialization -from airflow.settings import _ENABLE_AIP_44 -from airflow.utils.jwt_signer import JWTSigner -from airflow.utils.state import State -from airflow.www import app - -from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules -from tests_common.test_utils.mock_plugins import mock_plugin_manager - -AIRFLOW_VERSION = Version(airflow_version) -AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") - -# Note: Sounds a bit strange to disable internal API tests in isolation mode but... -# As long as the test is modelled to run its own internal API endpoints, it is conflicting -# to the test setup with a dedicated internal API server. 
-pytestmark = pytest.mark.db_test - - -def test_initialize_method_map(): - method_map = _initialize_method_map() - assert len(method_map) > 70 - for method in [ - # Test some basics - XCom.get_value, - XCom.get_one, - XCom.clear, - XCom.set, - DagRun.get_previous_dagrun, - DagRun.get_previous_scheduled_dagrun, - DagRun.get_task_instances, - DagRun.fetch_task_instance, - # Test some for Edge - EdgeJob.reserve_task, - EdgeJob.set_state, - EdgeLogs.push_logs, - EdgeWorker.register_worker, - EdgeWorker.set_state, - ]: - method_key = f"{method.__module__}.{method.__qualname__}" - assert method_key in method_map.keys() - - -if TYPE_CHECKING: - from flask import Flask - -TEST_METHOD_NAME = "test_method" -TEST_METHOD_WITH_LOG_NAME = "test_method_with_log" -TEST_API_ENDPOINT = "/edge_worker/v1/rpcapi" - -mock_test_method = mock.MagicMock() - -pytest.importorskip("pydantic", minversion="2.0.0") - - -def equals(a, b) -> bool: - return a == b - - -# Tests are written for Airflow 2.10, so we skip them for Airflow 3.0+ -# Unfortunately pytest fails in collection, therefore need a hard switch -if not AIRFLOW_V_3_0_PLUS: - - @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") - class TestRpcApiEndpointV2: - @pytest.fixture(scope="session") - def minimal_app_for_edge_api(self) -> Flask: - @dont_initialize_flask_app_submodules( - skip_all_except=[ - "init_api_auth", # This is needed for Airflow 2.10 compat tests - "init_appbuilder", - "init_plugins", - ] - ) - def factory() -> Flask: - import airflow.providers.edge.plugins.edge_executor_plugin as plugin_module - - class TestingEdgeExecutorPlugin(plugin_module.EdgeExecutorPlugin): - flask_blueprints = [ - plugin_module._get_airflow_2_api_endpoint(), - plugin_module.template_bp, - ] - - testing_edge_plugin = TestingEdgeExecutorPlugin() - assert len(testing_edge_plugin.flask_blueprints) > 0 - with mock_plugin_manager(plugins=[testing_edge_plugin]): - return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - - return factory() - - @pytest.fixture - def setup_attrs(self, minimal_app_for_edge_api: Flask) -> Generator: - self.app = minimal_app_for_edge_api - self.client = self.app.test_client() # type:ignore - mock_test_method.reset_mock() - mock_test_method.side_effect = None - with mock.patch( - "airflow.providers.edge.worker_api.routes.rpc_api._initialize_method_map" - ) as mock_initialize_method_map: - mock_initialize_method_map.return_value = { - TEST_METHOD_NAME: mock_test_method, - } - yield mock_initialize_method_map - - @pytest.fixture - def signer(self) -> JWTSigner: - return JWTSigner( - secret_key=conf.get("core", "internal_api_secret_key"), - expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), - audience="api", - ) - - @pytest.mark.parametrize( - "input_params, method_result, result_cmp_func, method_params", - [ - ({}, None, lambda got, _: got == b"", {}), - ({}, "test_me", equals, {}), - ( - BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}), - ("dag_id_15", "fake-task", 1), - equals, - {"dag_id": 15, "task_id": "fake-task"}, - ), - ( - {}, - TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING), - lambda a, b: a.model_dump() == TaskInstancePydantic.model_validate(b).model_dump() - and isinstance(a.task, BaseOperator), - {}, - ), - ( - {}, - Connection(conn_id="test_conn", conn_type="http", host="", password=""), - lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id, - {}, - ), - ], - ) - def 
test_method( - self, input_params, method_result, result_cmp_func, method_params, setup_attrs, signer: JWTSigner - ): - mock_test_method.return_value = method_result - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - input_data = { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": input_params, - } - response = self.client.post( - TEST_API_ENDPOINT, - headers=headers, - data=json.dumps(input_data), - ) - assert response.status_code == 200 - if method_result: - response_data = BaseSerialization.deserialize( - json.loads(response.data), use_pydantic_models=True - ) - else: - response_data = response.data - - assert result_cmp_func(response_data, method_result) - - mock_test_method.assert_called_once_with(**method_params, session=mock.ANY) - - def test_method_with_exception(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - mock_test_method.side_effect = ValueError("Error!!!") - data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 500 - assert response.data, b"Error executing method: test_method." - mock_test_method.assert_called_once() - - def test_unknown_method(self, setup_attrs, signer: JWTSigner): - UNKNOWN_METHOD = "i-bet-it-does-not-exist" - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": UNKNOWN_METHOD}), - } - data = {"jsonrpc": "2.0", "method": UNKNOWN_METHOD, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 400 - assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.") - mock_test_method.assert_not_called() - - def test_invalid_jsonrpc(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 400 - assert response.data.startswith(b"Expected jsonrpc 2.0 request.") - mock_test_method.assert_not_called() - - def test_missing_token(self, setup_attrs): - mock_test_method.return_value = None - - input_data = { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": {}, - } - response = self.client.post( - TEST_API_ENDPOINT, - headers={"Content-Type": "application/json", "Accept": "application/json"}, - data=json.dumps(input_data), - ) - assert response.status_code == 403 - assert "Unable to authenticate API via token." in response.text - - def test_invalid_token(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 403 - assert "Bad Signature. 
Please use only the tokens provided by the API." in response.text - - def test_missing_accept(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 403 - assert "Expected Accept: application/json" in response.text - - def test_wrong_accept(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/html", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data)) - assert response.status_code == 403 - assert "Expected Accept: application/json" in response.text diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index 29db9a36897f5..a72073537aa4f 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -64,7 +64,6 @@ get_current_context, ) from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv -from airflow.settings import _ENABLE_AIP_44 from airflow.utils import timezone from airflow.utils.context import AirflowContextDeprecationWarning, Context from airflow.utils.session import create_session @@ -92,8 +91,6 @@ CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") -USE_AIRFLOW_CONTEXT_MARKER = pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is not enabled") - AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE = ( r"The `use_airflow_context=True` is only supported in Airflow 3.0.0 and later." ) @@ -1046,7 +1043,6 @@ def f(): task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True) assert task.execute_callable() == "EFGHI" - @USE_AIRFLOW_CONTEXT_MARKER def test_current_context(self): def f(): from airflow.providers.standard.operators.python import get_current_context @@ -1066,7 +1062,6 @@ def f(): with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): self.run_as_task(f, return_ti=True, use_airflow_context=True) - @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_not_found_error(self): def f(): from airflow.providers.standard.operators.python import get_current_context @@ -1089,7 +1084,6 @@ def f(): ): self.run_as_task(f, return_ti=True, use_airflow_context=False) - @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_airflow_not_found_error(self): airflow_flag: dict[str, bool] = {"expect_airflow": False} error_msg = r"The `use_airflow_context` parameter is set to True, but expect_airflow is set to False." 
@@ -1116,7 +1110,6 @@ def f(): with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): self.run_as_task(f, return_ti=True, use_airflow_context=True, **airflow_flag) - @USE_AIRFLOW_CONTEXT_MARKER def test_use_airflow_context_touch_other_variables(self): def f(): from airflow.providers.standard.operators.python import get_current_context @@ -1136,22 +1129,6 @@ def f(): with pytest.raises(AirflowException, match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE): self.run_as_task(f, return_ti=True, use_airflow_context=True) - @pytest.mark.skipif(_ENABLE_AIP_44, reason="AIP-44 is enabled") - def test_use_airflow_context_without_aip_44_error(self): - def f(): - from airflow.providers.standard.operators.python import get_current_context - - get_current_context() - return [] - - error_msg = "`get_current_context()` needs to be used with AIP-44 enabled." - if AIRFLOW_V_3_0_PLUS: - with pytest.raises(AirflowException, match=re.escape(error_msg)): - self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) - else: - with pytest.raises(AirflowException, match=re.escape(AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE)): - self.run_as_task(f, return_ti=True, use_airflow_context=True) - venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path") @@ -1528,7 +1505,6 @@ def f( self.run_as_task(f, serializer=serializer, system_site_packages=False, requirements=None) - @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_system_site_packages(self, session): def f(): from airflow.providers.standard.operators.python import get_current_context @@ -1892,7 +1868,6 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): kwargs["venv_cache_path"] = venv_cache_path return kwargs - @USE_AIRFLOW_CONTEXT_MARKER def test_current_context_system_site_packages(self, session): def f(): from airflow.providers.standard.operators.python import get_current_context diff --git a/scripts/ci/docker-compose/devcontainer.env b/scripts/ci/docker-compose/devcontainer.env index 99207f377968b..1d1bd2c310d59 100644 --- a/scripts/ci/docker-compose/devcontainer.env +++ b/scripts/ci/docker-compose/devcontainer.env @@ -17,7 +17,6 @@ HOME= AIRFLOW_CI_IMAGE="ghcr.io/apache/airflow/main/ci/python3.9:latest" ANSWER= -AIRFLOW_ENABLE_AIP_44="true" AIRFLOW_ENV="development" PYTHON_MAJOR_MINOR_VERSION="3.9" AIRFLOW_EXTRAS= diff --git a/scripts/cov/other_coverage.py b/scripts/cov/other_coverage.py index 41c7a4352369d..0394e3590bec3 100644 --- a/scripts/cov/other_coverage.py +++ b/scripts/cov/other_coverage.py @@ -71,7 +71,6 @@ """ Other tests to potentially run against the source_file packages: - "tests/api_internal", "tests/auth", "tests/callbacks", "tests/charts", diff --git a/scripts/cov/restapi_coverage.py b/scripts/cov/restapi_coverage.py index b3750248e396f..af46468b1c7b5 100644 --- a/scripts/cov/restapi_coverage.py +++ b/scripts/cov/restapi_coverage.py @@ -25,7 +25,7 @@ source_files = ["airflow/api_connexion", "airflow/api_internal"] -restapi_files = ["tests/api_connexion", "tests/api_internal"] +restapi_files = ["tests/api_connexion"] files_not_fully_covered = [ "airflow/api_connexion/endpoints/forward_to_fab_endpoint.py", diff --git a/tests/api_internal/__init__.py b/tests/api_internal/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/api_internal/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/api_internal/endpoints/__init__.py b/tests/api_internal/endpoints/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/api_internal/endpoints/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py deleted file mode 100644 index 2b3d66103346c..0000000000000 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ /dev/null @@ -1,250 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import json -from collections.abc import Generator -from typing import TYPE_CHECKING -from unittest import mock - -import pytest - -from airflow.api_connexion.exceptions import PermissionDenied -from airflow.configuration import conf -from airflow.models.baseoperator import BaseOperator -from airflow.models.connection import Connection -from airflow.models.taskinstance import TaskInstance -from airflow.operators.empty import EmptyOperator -from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic -from airflow.serialization.serialized_objects import BaseSerialization -from airflow.settings import _ENABLE_AIP_44 -from airflow.utils.jwt_signer import JWTSigner -from airflow.utils.state import State -from airflow.www import app - -from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.decorators import dont_initialize_flask_app_submodules - -# Note: Sounds a bit strange to disable internal API tests in isolation mode but... -# As long as the test is modelled to run its own internal API endpoints, it is conflicting -# to the test setup with a dedicated internal API server. -pytestmark = pytest.mark.db_test - -if TYPE_CHECKING: - from flask import Flask - -TEST_METHOD_NAME = "test_method" -TEST_METHOD_WITH_LOG_NAME = "test_method_with_log" - -mock_test_method = mock.MagicMock() - -pytest.importorskip("pydantic", minversion="2.0.0") - - -@pytest.fixture(scope="session") -def minimal_app_for_internal_api() -> Flask: - @dont_initialize_flask_app_submodules( - skip_all_except=[ - "init_appbuilder", - "init_api_internal", - ] - ) - def factory() -> Flask: - with conf_vars({("webserver", "run_internal_api"): "true"}): - return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - - return factory() - - -def equals(a, b) -> bool: - return a == b - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -class TestRpcApiEndpoint: - @pytest.fixture - def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: - self.app = minimal_app_for_internal_api - self.client = self.app.test_client() # type:ignore - mock_test_method.reset_mock() - mock_test_method.side_effect = None - with mock.patch( - "airflow.api_internal.endpoints.rpc_api_endpoint.initialize_method_map" - ) as mock_initialize_method_map: - mock_initialize_method_map.return_value = { - TEST_METHOD_NAME: mock_test_method, - } - yield mock_initialize_method_map - - @pytest.fixture - def signer(self) -> JWTSigner: - return JWTSigner( - secret_key=conf.get("core", "internal_api_secret_key"), - expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30), - audience="api", - ) - - def test_initialize_method_map(self): - from airflow.api_internal.endpoints.rpc_api_endpoint import initialize_method_map - - method_map = initialize_method_map() - assert len(method_map) > 69 - - @pytest.mark.parametrize( - "input_params, method_result, result_cmp_func, method_params", - [ - ({}, None, lambda got, _: got == b"", {}), - ({}, "test_me", equals, {}), - ( - BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}), - ("dag_id_15", "fake-task", 1), - equals, - {"dag_id": 15, "task_id": "fake-task"}, - ), - ( - {}, - TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING), - lambda a, b: a.model_dump() == TaskInstancePydantic.model_validate(b).model_dump() - and isinstance(a.task, BaseOperator), - {}, - ), - ( - {}, - Connection(conn_id="test_conn", 
conn_type="http", host="", password=""), - lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id, - {}, - ), - ], - ) - def test_method( - self, input_params, method_result, result_cmp_func, method_params, setup_attrs, signer: JWTSigner - ): - mock_test_method.return_value = method_result - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - input_data = { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": input_params, - } - response = self.client.post( - "/internal_api/v1/rpcapi", - headers=headers, - data=json.dumps(input_data), - ) - assert response.status_code == 200 - if method_result: - response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) - else: - response_data = response.data - - assert result_cmp_func(response_data, method_result) - - mock_test_method.assert_called_once_with(**method_params, session=mock.ANY) - - def test_method_with_exception(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - mock_test_method.side_effect = ValueError("Error!!!") - data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) - assert response.status_code == 500 - assert response.data, b"Error executing method: test_method." - mock_test_method.assert_called_once() - - def test_unknown_method(self, setup_attrs, signer: JWTSigner): - UNKNOWN_METHOD = "i-bet-it-does-not-exist" - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": UNKNOWN_METHOD}), - } - data = {"jsonrpc": "2.0", "method": UNKNOWN_METHOD, "params": {}} - - response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) - assert response.status_code == 400 - assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.") - mock_test_method.assert_not_called() - - def test_invalid_jsonrpc(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - response = self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) - assert response.status_code == 400 - assert response.data.startswith(b"Expected jsonrpc 2.0 request.") - mock_test_method.assert_not_called() - - def test_missing_token(self, setup_attrs): - mock_test_method.return_value = None - - input_data = { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": {}, - } - with pytest.raises(PermissionDenied, match="Unable to authenticate API via token."): - self.client.post( - "/internal_api/v1/rpcapi", - headers={"Content-Type": "application/json", "Accept": "application/json"}, - data=json.dumps(input_data), - ) - - def test_invalid_token(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - with pytest.raises( - PermissionDenied, 
match="Bad Signature. Please use only the tokens provided by the API." - ): - self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) - - def test_missing_accept(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - with pytest.raises(PermissionDenied, match="Expected Accept: application/json"): - self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) - - def test_wrong_accept(self, setup_attrs, signer: JWTSigner): - headers = { - "Content-Type": "application/json", - "Accept": "application/html", - "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}), - } - data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}} - - with pytest.raises(PermissionDenied, match="Expected Accept: application/json"): - self.client.post("/internal_api/v1/rpcapi", headers=headers, data=json.dumps(data)) diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py deleted file mode 100644 index fe817afbfc7c5..0000000000000 --- a/tests/api_internal/test_internal_api_call.py +++ /dev/null @@ -1,327 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- - -from __future__ import annotations - -import json -from argparse import Namespace -from typing import TYPE_CHECKING -from unittest import mock - -import pytest -import requests -from tenacity import RetryError - -from airflow.__main__ import configure_internal_api -from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call -from airflow.configuration import conf -from airflow.models.taskinstance import TaskInstance -from airflow.operators.empty import EmptyOperator -from airflow.serialization.serialized_objects import BaseSerialization -from airflow.settings import _ENABLE_AIP_44 -from airflow.utils.state import State - -from tests_common.test_utils.config import conf_vars - -if TYPE_CHECKING: - from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic - -pytest.importorskip("pydantic", minversion="2.0.0") - - -@pytest.fixture(autouse=True) -def reset_init_api_config(): - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" - from airflow import settings - - old_engine = settings.engine - old_session = settings.Session - old_conn = settings.SQL_ALCHEMY_CONN - try: - yield - finally: - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" - settings.engine = old_engine - settings.Session = old_session - settings.SQL_ALCHEMY_CONN = old_conn - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -class TestInternalApiConfig: - @conf_vars( - { - ("core", "database_access_isolation"): "false", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - def test_get_use_internal_api_disabled(self): - configure_internal_api(Namespace(subcommand="webserver"), conf) - assert InternalApiConfig.get_use_internal_api() is False - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - def test_get_use_internal_api_enabled(self): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - assert InternalApiConfig.get_use_internal_api() is True - assert InternalApiConfig.get_internal_api_endpoint() == "http://localhost:8888/internal_api/v1/rpcapi" - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - } - ) - def test_force_database_direct_access(self): - InternalApiConfig.set_use_database_access("message") - assert InternalApiConfig.get_use_internal_api() is False - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -class TestInternalApiCall: - @staticmethod - @internal_api_call - def fake_method() -> str: - return "local-call" - - @staticmethod - @internal_api_call - def fake_method_with_params(dag_id: str, task_id: int, session) -> str: - return f"local-call-with-params-{dag_id}-{task_id}" - - @classmethod - @internal_api_call - def fake_class_method_with_params(cls, dag_id: str, session) -> str: - return f"local-classmethod-call-with-params-{dag_id}" - - @staticmethod - @internal_api_call - def fake_class_method_with_serialized_params( - ti: TaskInstance | TaskInstancePydantic, - session, - ) -> str: - return f"local-classmethod-call-with-serialized-{ti.task_id}" - - @conf_vars( - { - ("core", "database_access_isolation"): "false", - ("core", "internal_api_url"): "http://localhost:8888", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def 
test_local_call(self, mock_requests): - result = TestInternalApiCall.fake_method() - - assert result == "local-call" - mock_requests.post.assert_not_called() - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def test_remote_call(self, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 200 - - response._content = json.dumps(BaseSerialization.serialize("remote-call")) - - mock_requests.post.return_value = response - - result = TestInternalApiCall.fake_method() - assert result == "remote-call" - expected_data = json.dumps( - { - "jsonrpc": "2.0", - "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall.fake_method", - "params": BaseSerialization.serialize({}), - } - ) - mock_requests.post.assert_called_once() - call_kwargs: dict = mock_requests.post.call_args.kwargs - assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" - assert call_kwargs["data"] == expected_data - assert call_kwargs["headers"]["Content-Type"] == "application/json" - assert "Authorization" in call_kwargs["headers"] - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def test_remote_call_with_none_result(self, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 200 - response._content = b"" - - mock_requests.post.return_value = response - - result = TestInternalApiCall.fake_method() - assert result is None - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def test_remote_call_with_params(self, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 200 - - response._content = json.dumps(BaseSerialization.serialize("remote-call")) - - mock_requests.post.return_value = response - - result = TestInternalApiCall.fake_method_with_params("fake-dag", task_id=123, session="session") - - assert result == "remote-call" - expected_data = json.dumps( - { - "jsonrpc": "2.0", - "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall." 
- "fake_method_with_params", - "params": BaseSerialization.serialize( - { - "dag_id": "fake-dag", - "task_id": 123, - } - ), - } - ) - mock_requests.post.assert_called_once() - call_kwargs: dict = mock_requests.post.call_args.kwargs - assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" - assert call_kwargs["data"] == expected_data - assert call_kwargs["headers"]["Content-Type"] == "application/json" - assert "Authorization" in call_kwargs["headers"] - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def test_remote_classmethod_call_with_params(self, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 200 - - response._content = json.dumps(BaseSerialization.serialize("remote-call")) - - mock_requests.post.return_value = response - - result = TestInternalApiCall.fake_class_method_with_params("fake-dag", session="session") - - assert result == "remote-call" - expected_data = json.dumps( - { - "jsonrpc": "2.0", - "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall." - "fake_class_method_with_params", - "params": BaseSerialization.serialize( - { - "dag_id": "fake-dag", - } - ), - } - ) - mock_requests.post.assert_called_once() - call_kwargs: dict = mock_requests.post.call_args.kwargs - assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" - assert call_kwargs["data"] == expected_data - assert call_kwargs["headers"]["Content-Type"] == "application/json" - assert "Authorization" in call_kwargs["headers"] - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - @mock.patch("tenacity.time.sleep") - def test_retry_on_bad_gateway(self, mock_sleep, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 502 - response.reason = "Bad Gateway" - response._content = b"Bad Gateway" - - mock_sleep = lambda *_, **__: None # noqa: F841 - mock_requests.post.return_value = response - with pytest.raises(RetryError): - TestInternalApiCall.fake_method_with_params("fake-dag", task_id=123, session="session") - assert mock_requests.post.call_count == 10 - - @conf_vars( - { - ("core", "database_access_isolation"): "true", - ("core", "internal_api_url"): "http://localhost:8888", - ("database", "sql_alchemy_conn"): "none://", - } - ) - @mock.patch("airflow.api_internal.internal_api_call.requests") - def test_remote_call_with_serialized_model(self, mock_requests): - configure_internal_api(Namespace(subcommand="dag-processor"), conf) - response = requests.Response() - response.status_code = 200 - - response._content = json.dumps(BaseSerialization.serialize("remote-call")) - - mock_requests.post.return_value = response - ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING) - - result = TestInternalApiCall.fake_class_method_with_serialized_params(ti, session="session") - - assert result == "remote-call" - expected_data = json.dumps( - { - "jsonrpc": "2.0", - "method": "tests.api_internal.test_internal_api_call.TestInternalApiCall." 
- "fake_class_method_with_serialized_params", - "params": BaseSerialization.serialize({"ti": ti}, use_pydantic_models=True), - } - ) - mock_requests.post.assert_called_once() - call_kwargs: dict = mock_requests.post.call_args.kwargs - assert call_kwargs["url"] == "http://localhost:8888/internal_api/v1/rpcapi" - assert call_kwargs["data"] == expected_data - assert call_kwargs["headers"]["Content-Type"] == "application/json" - assert "Authorization" in call_kwargs["headers"] diff --git a/tests/cli/commands/test_internal_api_command.py b/tests/cli/commands/test_internal_api_command.py deleted file mode 100644 index 329e8bfbe5336..0000000000000 --- a/tests/cli/commands/test_internal_api_command.py +++ /dev/null @@ -1,223 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import os -import subprocess -import sys -import time -from unittest import mock - -import psutil -import pytest -from rich.console import Console - -from airflow import settings -from airflow.cli import cli_parser -from airflow.cli.commands import internal_api_command -from airflow.cli.commands.internal_api_command import GunicornMonitor -from airflow.settings import _ENABLE_AIP_44 - -from tests.cli.commands._common_cli_classes import _CommonCLIGunicornTestClass -from tests_common.test_utils.config import conf_vars - -console = Console(width=400, color_system="standard") - - -class TestCLIGetNumReadyWorkersRunning: - @classmethod - def setup_class(cls): - cls.parser = cli_parser.get_parser() - - def setup_method(self): - self.children = mock.MagicMock() - self.child = mock.MagicMock() - self.process = mock.MagicMock() - self.monitor = GunicornMonitor( - gunicorn_master_pid=1, - num_workers_expected=4, - master_timeout=60, - worker_refresh_interval=60, - worker_refresh_batch_size=2, - reload_on_plugin_change=True, - ) - - def test_ready_prefix_on_cmdline(self): - self.child.cmdline.return_value = [settings.GUNICORN_WORKER_READY_PREFIX] - self.process.children.return_value = [self.child] - - with mock.patch("psutil.Process", return_value=self.process): - assert self.monitor._get_num_ready_workers_running() == 1 - - def test_ready_prefix_on_cmdline_no_children(self): - self.process.children.return_value = [] - - with mock.patch("psutil.Process", return_value=self.process): - assert self.monitor._get_num_ready_workers_running() == 0 - - def test_ready_prefix_on_cmdline_zombie(self): - self.child.cmdline.return_value = [] - self.process.children.return_value = [self.child] - - with mock.patch("psutil.Process", return_value=self.process): - assert self.monitor._get_num_ready_workers_running() == 0 - - def test_ready_prefix_on_cmdline_dead_process(self): - self.child.cmdline.side_effect = psutil.NoSuchProcess(11347) - self.process.children.return_value = 
[self.child] - - with mock.patch("psutil.Process", return_value=self.process): - assert self.monitor._get_num_ready_workers_running() == 0 - - -@pytest.mark.db_test -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -class TestCliInternalAPI(_CommonCLIGunicornTestClass): - main_process_regexp = r"airflow internal-api" - - @pytest.mark.execution_timeout(210) - def test_cli_internal_api_background(self, tmp_path): - parent_path = tmp_path / "gunicorn" - parent_path.mkdir() - pidfile_internal_api = parent_path / "pidflow-internal-api.pid" - pidfile_monitor = parent_path / "pidflow-internal-api-monitor.pid" - stdout = parent_path / "airflow-internal-api.out" - stderr = parent_path / "airflow-internal-api.err" - logfile = parent_path / "airflow-internal-api.log" - try: - # Run internal-api as daemon in background. Note that the wait method is not called. - console.print("[magenta]Starting airflow internal-api --daemon") - env = os.environ.copy() - env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "true" - proc = subprocess.Popen( - [ - "airflow", - "internal-api", - "--daemon", - "--pid", - os.fspath(pidfile_internal_api), - "--stdout", - os.fspath(stdout), - "--stderr", - os.fspath(stderr), - "--log-file", - os.fspath(logfile), - ], - env=env, - ) - assert proc.poll() is None - - pid_monitor = self._wait_pidfile(pidfile_monitor) - console.print(f"[blue]Monitor started at {pid_monitor}") - pid_internal_api = self._wait_pidfile(pidfile_internal_api) - console.print(f"[blue]Internal API started at {pid_internal_api}") - console.print("[blue]Running airflow internal-api process:") - # Assert that the internal-api and gunicorn processes are running (by name rather than pid). - assert self._find_process(r"airflow internal-api --daemon", print_found_process=True) - console.print("[blue]Waiting for gunicorn processes:") - # wait for gunicorn to start - for _ in range(30): - if self._find_process(r"^gunicorn"): - break - console.print("[blue]Waiting for gunicorn to start ...") - time.sleep(1) - console.print("[blue]Running gunicorn processes:") - assert self._find_all_processes("^gunicorn", print_found_process=True) - console.print("[magenta]Internal-api process started successfully.") - console.print( - "[magenta]Terminating monitor process and expect " - "internal-api and gunicorn processes to terminate as well" - ) - self._terminate_multiple_process([pid_internal_api, pid_monitor]) - self._check_processes(ignore_running=False) - console.print("[magenta]All internal-api and gunicorn processes are terminated.") - except Exception: - console.print("[red]Exception occurred. 
Dumping all logs.") - # Dump all logs - for file in parent_path.glob("*"): - console.print(f"Dumping {file} (size: {file.stat().st_size})") - console.print(file.read_text()) - raise - - @conf_vars({("core", "database_access_isolation"): "true"}) - def test_cli_internal_api_debug(self, app): - with ( - mock.patch("airflow.cli.commands.internal_api_command.create_app", return_value=app), - mock.patch.object(app, "run") as app_run, - ): - args = self.parser.parse_args( - [ - "internal-api", - "--debug", - ] - ) - internal_api_command.internal_api(args) - - app_run.assert_called_with( - debug=True, - use_reloader=False, - port=9080, - host="0.0.0.0", - ) - - @conf_vars({("core", "database_access_isolation"): "true"}) - def test_cli_internal_api_args(self): - with ( - mock.patch("subprocess.Popen") as Popen, - mock.patch.object(internal_api_command, "GunicornMonitor"), - ): - args = self.parser.parse_args( - [ - "internal-api", - "--access-logformat", - "custom_log_format", - "--pid", - "/tmp/x.pid", - ] - ) - internal_api_command.internal_api(args) - - Popen.assert_called_with( - [ - sys.executable, - "-m", - "gunicorn", - "--workers", - "4", - "--worker-class", - "sync", - "--timeout", - "120", - "--bind", - "0.0.0.0:9080", - "--name", - "airflow-internal-api", - "--pid", - "/tmp/x.pid", - "--access-logfile", - "-", - "--error-logfile", - "-", - "--config", - "python:airflow.api_internal.gunicorn_config", - "--access-logformat", - "custom_log_format", - "airflow.cli.commands.internal_api_command:cached_app()", - "--preload", - ], - close_fds=True, - ) diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py index 8afab5ad03a3f..244be2fc9ee16 100644 --- a/tests/core/test_settings.py +++ b/tests/core/test_settings.py @@ -26,7 +26,6 @@ import pytest -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException from airflow.settings import is_usage_data_collection_enabled @@ -65,25 +64,6 @@ def task_must_have_owners(task: BaseOperator): """ -@pytest.fixture -def clear_internal_api(): - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" - from airflow import settings - - old_engine = settings.engine - old_session = settings.Session - old_conn = settings.SQL_ALCHEMY_CONN - try: - yield - finally: - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" - settings.engine = old_engine - settings.Session = old_session - settings.SQL_ALCHEMY_CONN = old_conn - - class SettingsContext: def __init__(self, content: str, module_name: str): self.content = content @@ -328,7 +308,7 @@ def test_encoding_absent_in_v2(is_v1, mock_conf): (None, "False", False), # Default env, conf disables ], ) -def test_usage_data_collection_disabled(env_var, conf_setting, is_enabled, clear_internal_api): +def test_usage_data_collection_disabled(env_var, conf_setting, is_enabled): conf_patch = conf_vars({("usage_data_collection", "enabled"): conf_setting}) if env_var is not None: diff --git a/tests/core/test_sqlalchemy_config.py b/tests/core/test_sqlalchemy_config.py index c0ac11502a41a..bbf5974e621d2 100644 --- a/tests/core/test_sqlalchemy_config.py +++ b/tests/core/test_sqlalchemy_config.py @@ -23,7 +23,6 @@ from sqlalchemy.pool import NullPool from airflow import settings -from airflow.api_internal.internal_api_call import InternalApiConfig from airflow.exceptions import AirflowConfigException from tests_common.test_utils.config import conf_vars 
@@ -38,16 +37,12 @@ def setup_method(self): self.old_engine = settings.engine self.old_session = settings.Session self.old_conn = settings.SQL_ALCHEMY_CONN - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" settings.SQL_ALCHEMY_CONN = "mysql+foobar://user:pass@host/dbname?inline=param&another=param" def teardown_method(self): settings.engine = self.old_engine settings.Session = self.old_session settings.SQL_ALCHEMY_CONN = self.old_conn - InternalApiConfig._use_internal_api = False - InternalApiConfig._internal_api_endpoint = "" @patch("airflow.settings.setup_event_handlers") @patch("airflow.settings.scoped_session") diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 25b94e7bab1dd..179fe8f46f4ff 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -80,7 +80,6 @@ from airflow.sdk.definitions.asset import AssetAlias from airflow.sensors.base import BaseSensorOperator from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG -from airflow.settings import TracebackSessionForTests from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS @@ -1544,9 +1543,7 @@ def do_something_else(i): dep_context=DepContext(flag_upstream_failed=flag_upstream_failed), session=session, ) - TracebackSessionForTests.set_allow_db_access(session, True) completed = all(dep.passed for dep in dep_results) - TracebackSessionForTests.set_allow_db_access(session, False) ti = dr.get_task_instance(task_id="do_something_else", map_index=3, session=session) assert completed == expect_completed diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index a631e77c0f0a2..a8a6b3c262903 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -32,7 +32,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator from airflow.providers.standard.triggers.external_task import DagStateTrigger -from airflow.settings import TracebackSessionForTests from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState @@ -73,11 +72,9 @@ def setup_method(self): session.commit() def re_sync_triggered_dag_to_db(self, dag, dag_maker): - TracebackSessionForTests.set_allow_db_access(dag_maker.session, True) dagbag = DagBag(self.f_name, read_dags_from_db=False, include_examples=False) dagbag.bag_dag(dag) dagbag.sync_to_db(session=dag_maker.session) - TracebackSessionForTests.set_allow_db_access(dag_maker.session, False) def teardown_method(self): """Cleanup state after testing in DB.""" diff --git a/tests/serialization/test_pydantic_models.py b/tests/serialization/test_pydantic_models.py deleted file mode 100644 index eb5504734b54c..0000000000000 --- a/tests/serialization/test_pydantic_models.py +++ /dev/null @@ -1,295 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import datetime - -import pytest -from dateutil import relativedelta - -from airflow.decorators import task -from airflow.decorators.python import _PythonDecoratedOperator -from airflow.jobs.job import Job -from airflow.jobs.local_task_job_runner import LocalTaskJobRunner -from airflow.models import MappedOperator -from airflow.models.asset import ( - AssetEvent, - AssetModel, - DagScheduleAssetReference, - TaskOutletAssetReference, -) -from airflow.models.dag import DAG, DagModel -from airflow.serialization.pydantic.asset import AssetEventPydantic -from airflow.serialization.pydantic.dag import DagModelPydantic -from airflow.serialization.pydantic.dag_run import DagRunPydantic -from airflow.serialization.pydantic.job import JobPydantic -from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic -from airflow.serialization.serialized_objects import BaseSerialization -from airflow.settings import _ENABLE_AIP_44, TracebackSessionForTests -from airflow.utils import timezone -from airflow.utils.state import State -from airflow.utils.types import AttributeRemoved, DagRunType - -from tests.models import DEFAULT_DATE -from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS - -if AIRFLOW_V_3_0_PLUS: - from airflow.utils.types import DagRunTriggeredByType - -pytestmark = pytest.mark.db_test - -pytest.importorskip("pydantic", minversion="2.0.0") - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -def test_serializing_pydantic_task_instance(session, create_task_instance): - dag_id = "test-dag" - ti = create_task_instance(dag_id=dag_id, session=session) - ti.state = State.RUNNING - ti.next_kwargs = {"foo": "bar"} - session.commit() - - pydantic_task_instance = TaskInstancePydantic.model_validate(ti) - - json_string = pydantic_task_instance.model_dump_json() - print(json_string) - - deserialized_model = TaskInstancePydantic.model_validate_json(json_string) - assert deserialized_model.dag_id == dag_id - assert deserialized_model.state == State.RUNNING - assert deserialized_model.try_number == ti.try_number - assert deserialized_model.logical_date == ti.logical_date - assert deserialized_model.next_kwargs == {"foo": "bar"} - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker): - op_class_dict_expected = { - "_needs_expansion": True, - "task_type": "_PythonDecoratedOperator", - "downstream_task_ids": [], - "start_from_trigger": False, - "start_trigger_args": None, - "ui_fgcolor": "#000", - "ui_color": "#ffefeb", - "template_fields": ["templates_dict", "op_args", "op_kwargs"], - "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}, - "template_ext": [], - "task_id": "target", - } - - with dag_maker() as dag: - - @task - def source(): - return [1, 2, 3] - - @task - def target(val=None): - print(val) - - # source() >> target() - target.expand(val=source()) - dr = dag_maker.create_dagrun() - ti = dr.task_instances[1] - - # roundtrip task - ser_task = 
BaseSerialization.serialize(ti.task, use_pydantic_models=True) - deser_task = BaseSerialization.deserialize(ser_task, use_pydantic_models=True) - ti.task.operator_class - # this is part of the problem! - assert isinstance(ti.task.operator_class, type) - assert isinstance(deser_task.operator_class, dict) - - assert ti.task.operator_class == _PythonDecoratedOperator - ti.refresh_from_task(deser_task) - # roundtrip ti - sered = BaseSerialization.serialize(ti, use_pydantic_models=True) - desered = BaseSerialization.deserialize(sered, use_pydantic_models=True) - assert desered.task.dag.__class__ is AttributeRemoved - assert "operator_class" not in sered["__var"]["task"] - - assert desered.task.__class__ == MappedOperator - - assert desered.task.operator_class == op_class_dict_expected - assert desered.task.task_type == "_PythonDecoratedOperator" - assert desered.task.operator_name == "@task" - - desered.refresh_from_task(deser_task) - - assert desered.task.__class__ == MappedOperator - - assert isinstance(desered.task.operator_class, dict) - - # let's check that we can safely add back dag... - assert isinstance(dag, DAG) - # dag already has this task - assert dag.has_task(desered.task.task_id) is True - # but the task has no dag - assert desered.task.dag.__class__ is AttributeRemoved - # and there are no upstream / downstreams on the task cus those are wiped out on serialization - # and this is wrong / not great but that's how it is - assert desered.task.upstream_task_ids == set() - assert desered.task.downstream_task_ids == set() - # add the dag back - desered.task.dag = dag - # great, no error - # but still, there are no upstream downstreams - assert desered.task.upstream_task_ids == set() - assert desered.task.downstream_task_ids == set() - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -def test_serializing_pydantic_dagrun(session, create_task_instance): - dag_id = "test-dag" - ti = create_task_instance(dag_id=dag_id, session=session) - ti.dag_run.state = State.RUNNING - session.commit() - - pydantic_dag_run = DagRunPydantic.model_validate(ti.dag_run) - - json_string = pydantic_dag_run.model_dump_json() - print(json_string) - - deserialized_model = DagRunPydantic.model_validate_json(json_string) - assert deserialized_model.dag_id == dag_id - assert deserialized_model.state == State.RUNNING - - -@pytest.mark.parametrize( - "schedule", - [ - None, - "*/10 * * *", - datetime.timedelta(days=1), - relativedelta.relativedelta(days=+12), - ], -) -def test_serializing_pydantic_dagmodel(schedule): - dag_model = DagModel( - dag_id="test-dag", - fileloc="/tmp/dag_1.py", - timetable_summary="summary", - timetable_description="desc", - is_active=True, - is_paused=False, - ) - - pydantic_dag_model = DagModelPydantic.model_validate(dag_model) - json_string = pydantic_dag_model.model_dump_json() - - deserialized_model = DagModelPydantic.model_validate_json(json_string) - assert deserialized_model.dag_id == "test-dag" - assert deserialized_model.fileloc == "/tmp/dag_1.py" - assert deserialized_model.timetable_summary == "summary" - assert deserialized_model.timetable_description == "desc" - assert deserialized_model.is_active is True - assert deserialized_model.is_paused is False - - -def test_serializing_pydantic_local_task_job(session, create_task_instance): - dag_id = "test-dag" - ti = create_task_instance(dag_id=dag_id, session=session) - ltj = Job(dag_id=ti.dag_id) - LocalTaskJobRunner(job=ltj, task_instance=ti) - ltj.state = State.RUNNING - session.commit() - pydantic_job = 
JobPydantic.model_validate(ltj) - - json_string = pydantic_job.model_dump_json() - - deserialized_model = JobPydantic.model_validate_json(json_string) - assert deserialized_model.dag_id == dag_id - assert deserialized_model.job_type == "LocalTaskJob" - assert deserialized_model.state == State.RUNNING - - -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") -def test_serializing_pydantic_asset_event(session, create_task_instance, create_dummy_dag): - ds1 = AssetModel(id=1, uri="one", extra={"foo": "bar"}) - ds2 = AssetModel(id=2, uri="two") - - session.add_all([ds1, ds2]) - session.commit() - - # it's easier to fake a manual run here - dag, task1 = create_dummy_dag( - dag_id="test_triggering_asset_events", - schedule=None, - start_date=DEFAULT_DATE, - task_id="test_context", - with_dagrun_type=DagRunType.MANUAL, - session=session, - ) - logical_date = timezone.utcnow() - TracebackSessionForTests.set_allow_db_access(session, True) - - triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} - dr = dag.create_dagrun( - run_id="test2", - run_type=DagRunType.ASSET_TRIGGERED, - logical_date=logical_date, - state=None, - session=session, - data_interval=(logical_date, logical_date), - **triggered_by_kwargs, - ) - asset1_event = AssetEvent(asset_id=1) - asset2_event_1 = AssetEvent(asset_id=2) - asset2_event_2 = AssetEvent(asset_id=2) - - dag_asset_ref = DagScheduleAssetReference(dag_id=dag.dag_id) - session.add(dag_asset_ref) - dag_asset_ref.asset = ds1 - task_ds_ref = TaskOutletAssetReference(task_id=task1.task_id, dag_id=dag.dag_id) - session.add(task_ds_ref) - task_ds_ref.asset = ds1 - - dr.consumed_asset_events.append(asset1_event) - dr.consumed_asset_events.append(asset2_event_1) - dr.consumed_asset_events.append(asset2_event_2) - session.commit() - TracebackSessionForTests.set_allow_db_access(session, False) - - print(asset2_event_2.asset.consuming_dags) - pydantic_dse1 = AssetEventPydantic.model_validate(asset1_event) - json_string1 = pydantic_dse1.model_dump_json() - print(json_string1) - - pydantic_dse2 = AssetEventPydantic.model_validate(asset2_event_1) - json_string2 = pydantic_dse2.model_dump_json() - print(json_string2) - - pydantic_dag_run = DagRunPydantic.model_validate(dr) - json_string_dr = pydantic_dag_run.model_dump_json() - print(json_string_dr) - - deserialized_model1 = AssetEventPydantic.model_validate_json(json_string1) - assert deserialized_model1.asset.id == 1 - assert deserialized_model1.asset.uri == "one" - assert len(deserialized_model1.asset.consuming_dags) == 1 - assert len(deserialized_model1.asset.producing_tasks) == 1 - - deserialized_model2 = AssetEventPydantic.model_validate_json(json_string2) - assert deserialized_model2.asset.id == 2 - assert deserialized_model2.asset.uri == "two" - assert len(deserialized_model2.asset.consuming_dags) == 0 - assert len(deserialized_model2.asset.producing_tasks) == 0 - - deserialized_dr = DagRunPydantic.model_validate_json(json_string_dr) - assert len(deserialized_dr.consumed_asset_events) == 3 diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py index 7eb82d08f19cf..75ff736be8733 100644 --- a/tests/serialization/test_serialized_objects.py +++ b/tests/serialization/test_serialized_objects.py @@ -58,7 +58,6 @@ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from airflow.serialization.pydantic.tasklog import LogTemplatePydantic from airflow.serialization.serialized_objects import 
BaseSerialization -from airflow.settings import _ENABLE_AIP_44 from airflow.triggers.base import BaseTrigger from airflow.utils import timezone from airflow.utils.context import OutletEventAccessor, OutletEventAccessors @@ -332,7 +331,6 @@ def test_backcompat_deserialize_connection(conn_uri): } -@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") @pytest.mark.parametrize( "input, pydantic_class, encoded_type, cmp_func", [ @@ -415,8 +413,6 @@ def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp def test_all_pydantic_models_round_trip(): pytest.importorskip("pydantic", minversion="2.0.0") - if not _ENABLE_AIP_44: - pytest.skip("AIP-44 is disabled") classes = set() mods_folder = REPO_ROOT / "airflow/serialization/pydantic" for p in mods_folder.iterdir(): diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py index 692c2ff0e9d9d..b0b76e8cf9cb0 100644 --- a/tests_common/pytest_plugin.py +++ b/tests_common/pytest_plugin.py @@ -142,7 +142,6 @@ os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "True" os.environ["AWS_DEFAULT_REGION"] = os.environ.get("AWS_DEFAULT_REGION") or "us-east-1" os.environ["CREDENTIALS_DIR"] = os.environ.get("CREDENTIALS_DIR") or "/files/airflow-breeze-config/keys" -os.environ["AIRFLOW_ENABLE_AIP_44"] = os.environ.get("AIRFLOW_ENABLE_AIP_44") or "true" if platform.system() == "Darwin": # mocks from unittest.mock work correctly in subprocesses only if they are created by "fork" method diff --git a/tests_common/test_utils/compat.py b/tests_common/test_utils/compat.py index 277b4f9c1012b..5ec9c4edab433 100644 --- a/tests_common/test_utils/compat.py +++ b/tests_common/test_utils/compat.py @@ -44,11 +44,6 @@ AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0") AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") -if AIRFLOW_V_3_0_PLUS: - os.environ["AIRFLOW_ENABLE_AIP_44"] = os.environ.get("AIRFLOW_ENABLE_AIP_44", "true") -else: - os.environ["AIRFLOW_ENABLE_AIP_44"] = "false" - try: from airflow.models.baseoperatorlink import BaseOperatorLink except ImportError: