diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 17da1eafacc93..2d7da4bff7376 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Literal, Sequence, overload -from airflow.utils.db import get_query_count -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from sqlalchemy.ext.asyncio import AsyncSession + +from airflow.utils.db import get_query_count, get_query_count_async +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, base_select: Select, filters: Sequence[BaseParam | None] | None = None + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, ) -> Select: if filters is None: return base_select @@ -65,6 +69,80 @@ def apply_filters_to_select( return base_select +async def get_async_session() -> AsyncSession: + """ + Dependency for providing a session. + + Example usage: + + .. code:: python + + @router.get("/your_path") + def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): + pass + """ + async with create_session_async() as session: + yield session + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[True] = True, +) -> tuple[Select, int]: ... + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[False], +) -> tuple[Select, None]: ... + + +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: bool = True, +) -> tuple[Select, int | None]: + query = apply_filters_to_select( + base_select=query, + filters=filters, + ) + + total_entries = None + if return_total_entries: + total_entries = await get_query_count_async(query, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + query = apply_filters_to_select( + base_select=query, + filters=[order_by, offset, limit], + ) + + return query, total_entries + + @overload def paginated_select( *, diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index aa6f540d32791..78b2beb558895 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -20,9 +20,10 @@ from fastapi import Depends, HTTPException, status from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.backfills import ( @@ -49,7 +50,7 @@ @backfills_router.get( path="", ) -def list_backfills( +async def list_backfills( dag_id: str, limit: QueryLimit, offset: QueryOffset, @@ -57,18 +58,16 @@ def list_backfills( SortParam, Depends(SortParam(["id"], Backfill).dynamic_depends()), ], - session: Annotated[Session, Depends(get_session)], + session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: - select_stmt, total_entries = paginated_select( - select=select(Backfill).where(Backfill.dag_id == dag_id), + select_stmt, total_entries = await paginated_select_async( + query=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, session=session, ) - - backfills = session.scalars(select_stmt) - + backfills = await session.scalars(select_stmt) return BackfillCollectionResponse( backfills=backfills, total_entries=total_entries, diff --git a/airflow/settings.py b/airflow/settings.py index 5b458efcba473..76b3e948964f3 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -31,7 +31,7 @@ import pluggy from packaging.version import Version from sqlalchemy import create_engine, exc, text -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool @@ -111,7 +111,7 @@ # this is achieved by the Session factory above. NonScopedSession: Callable[..., SASession] async_engine: AsyncEngine -create_async_session: Callable[..., AsyncSession] +AsyncSession: Callable[..., SAAsyncSession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine global async_engine - global create_async_session + global AsyncSession global NonScopedSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True) - create_async_session = sessionmaker( + AsyncSession = sessionmaker( bind=async_engine, autocommit=False, autoflush=False, - class_=AsyncSession, + class_=SAAsyncSession, expire_on_commit=False, ) mask_secret(engine.url.password) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d8939a117317f..c899ebf615d06 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -70,6 +70,7 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: + """ + Get count of a query. + + A SELECT COUNT() FROM is issued against the subquery built from the + given statement. The ORDER BY clause is stripped from the statement + since it's unnecessary for COUNT, and can impact query planning and + degrade performance. + + :meta private: + """ + count_stmt = select(func.count()).select_from(query.order_by(None).subquery()) + return await session.scalar(count_stmt) + + def check_query_exists(query_stmt: Select, *, session: Session) -> bool: """ Check whether there is at least one row matching a query. diff --git a/airflow/utils/session.py b/airflow/utils/session.py index a63d3f3f937a8..49383cdf4a8bf 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_session_async(): + """ + Context manager to create async session. + + :meta private: + """ + from airflow.settings import AsyncSession + + async with AsyncSession() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + PS = ParamSpec("PS") RT = TypeVar("RT") diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py index 02cba9e070dc4..8d26a25c626a5 100644 --- a/tests/utils/test_session.py +++ b/tests/utils/test_session.py @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self): @pytest.mark.asyncio async def test_async_session(self): - from airflow.settings import create_async_session + from airflow.settings import AsyncSession - session = create_async_session() + session = AsyncSession() session.add(Log(event="hihi1234")) await session.commit() my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))