Skip to content

Commit

Permalink
AIP-84 Introduce SessionDep and AsyncSessionDep construct (apache#44461)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrejeambrun authored Nov 29, 2024
1 parent b2d2bcb commit b882246
Show file tree
Hide file tree
Showing 27 changed files with 140 additions and 200 deletions.
39 changes: 11 additions & 28 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,29 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, overload
from typing import TYPE_CHECKING, Annotated, Literal, overload

from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from airflow.utils.db import get_query_count, get_query_count_async
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select

from airflow.api_fastapi.common.parameters import BaseParam


def get_session() -> Session:
"""
Dependency for providing a session.
For non route function please use the :class:`airflow.utils.session.provide_session` decorator.
Example usage:
.. code:: python
@router.get("/your_path")
def your_route(session: Annotated[Session, Depends(get_session)]):
pass
"""
def _get_session() -> Session:
with create_session(scoped=False) as session:
yield session


SessionDep = Annotated[Session, Depends(_get_session)]


def apply_filters_to_select(
*, statement: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
Expand All @@ -68,22 +59,14 @@ def apply_filters_to_select(
return statement


async def get_async_session() -> AsyncSession:
"""
Dependency for providing a session.
Example usage:
.. code:: python
@router.get("/your_path")
def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
pass
"""
async def _get_async_session() -> AsyncSession:
async with create_session_async() as session:
yield session


AsyncSessionDep = Annotated[AsyncSession, Depends(_get_async_session)]


@overload
async def paginated_select_async(
*,
Expand Down
24 changes: 12 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

from fastapi import Depends, HTTPException, status
from sqlalchemy import delete, select
from sqlalchemy.orm import Session, joinedload, subqueryload
from sqlalchemy.orm import joinedload, subqueryload

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import (
OptionalDateTimeQuery,
QueryAssetDagIdPatternSearch,
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_assets(
SortParam,
Depends(SortParam(["id", "uri", "created_at", "updated_at"], AssetModel).dynamic_depends()),
],
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> AssetCollectionResponse:
"""Get assets."""
assets_select, total_entries = paginated_select(
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_asset_events(
source_task_id: QuerySourceTaskIdFilter,
source_run_id: QuerySourceRunIdFilter,
source_map_index: QuerySourceMapIndexFilter,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> AssetEventCollectionResponse:
"""Get asset events."""
assets_event_select, total_entries = paginated_select(
Expand All @@ -168,7 +168,7 @@ def get_asset_events(
)
def create_asset_event(
body: CreateAssetEventsBody,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> AssetEventResponse:
"""Create asset events."""
asset = session.scalar(select(AssetModel).where(AssetModel.uri == body.uri).limit(1))
Expand Down Expand Up @@ -198,7 +198,7 @@ def create_asset_event(
)
def get_asset_queued_events(
uri: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
) -> QueuedEventCollectionResponse:
"""Get queued asset events for an asset."""
Expand Down Expand Up @@ -233,7 +233,7 @@ def get_asset_queued_events(
)
def get_asset(
uri: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> AssetResponse:
"""Get an asset."""
asset = session.scalar(
Expand All @@ -258,7 +258,7 @@ def get_asset(
)
def get_dag_asset_queued_events(
dag_id: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
) -> QueuedEventCollectionResponse:
"""Get queued asset events for a DAG."""
Expand Down Expand Up @@ -296,7 +296,7 @@ def get_dag_asset_queued_events(
def get_dag_asset_queued_event(
dag_id: str,
uri: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
) -> QueuedEventResponse:
"""Get a queued asset event for a DAG."""
Expand Down Expand Up @@ -327,7 +327,7 @@ def get_dag_asset_queued_event(
)
def delete_asset_queued_events(
uri: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
):
"""Delete queued asset events for an asset."""
Expand All @@ -350,7 +350,7 @@ def delete_asset_queued_events(
)
def delete_dag_asset_queued_events(
dag_id: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
):
where_clause = _generate_queued_event_where_clause(dag_id=dag_id, before=before)
Expand All @@ -375,7 +375,7 @@ def delete_dag_asset_queued_events(
def delete_dag_asset_queued_event(
dag_id: str,
uri: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
before: OptionalDateTimeQuery = None,
):
"""Delete a queued asset event for a DAG."""
Expand Down
18 changes: 10 additions & 8 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

from fastapi import Depends, HTTPException, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
from airflow.api_fastapi.common.db.common import (
AsyncSessionDep,
SessionDep,
paginated_select_async,
)
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.backfills import (
Expand Down Expand Up @@ -58,7 +60,7 @@ async def list_backfills(
SortParam,
Depends(SortParam(["id"], Backfill).dynamic_depends()),
],
session: Annotated[AsyncSession, Depends(get_async_session)],
session: AsyncSessionDep,
) -> BackfillCollectionResponse:
select_stmt, total_entries = await paginated_select_async(
statement=select(Backfill).where(Backfill.dag_id == dag_id),
Expand All @@ -80,7 +82,7 @@ async def list_backfills(
)
def get_backfill(
backfill_id: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> BackfillResponse:
backfill = session.get(Backfill, backfill_id)
if backfill:
Expand All @@ -97,7 +99,7 @@ def get_backfill(
]
),
)
def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
def pause_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
b = session.get(Backfill, backfill_id)
if not b:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")
Expand All @@ -118,7 +120,7 @@ def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)
]
),
)
def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
def unpause_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
b = session.get(Backfill, backfill_id)
if not b:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")
Expand All @@ -138,7 +140,7 @@ def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_sessio
]
),
)
def cancel_backfill(backfill_id, session: Annotated[Session, Depends(get_session)]) -> BackfillResponse:
def cancel_backfill(backfill_id, session: SessionDep) -> BackfillResponse:
b: Backfill = session.get(Backfill, backfill_id)
if not b:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Could not find backfill with id {backfill_id}")
Expand Down
13 changes: 6 additions & 7 deletions airflow/api_fastapi/core_api/routes/public/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.connections import (
Expand All @@ -49,7 +48,7 @@
)
def delete_connection(
connection_id: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
):
"""Delete a connection entry."""
connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))
Expand All @@ -68,7 +67,7 @@ def delete_connection(
)
def get_connection(
connection_id: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> ConnectionResponse:
"""Get a connection entry."""
connection = session.scalar(select(Connection).filter_by(conn_id=connection_id))
Expand Down Expand Up @@ -98,7 +97,7 @@ def get_connections(
).dynamic_depends()
),
],
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> ConnectionCollectionResponse:
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
Expand All @@ -124,7 +123,7 @@ def get_connections(
)
def post_connection(
post_body: ConnectionBody,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
) -> ConnectionResponse:
"""Create connection entry."""
try:
Expand Down Expand Up @@ -157,7 +156,7 @@ def post_connection(
def patch_connection(
connection_id: str,
patch_body: ConnectionBody,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
update_mask: list[str] | None = Query(None),
) -> ConnectionResponse:
"""Update a connection entry."""
Expand Down
9 changes: 4 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/dag_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated
from typing import TYPE_CHECKING

from fastapi import Depends, HTTPException, Request, status
from fastapi import HTTPException, Request, status
from itsdangerous import BadSignature, URLSafeSerializer
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.auth.managers.models.resource_details import DagDetails
Expand All @@ -44,7 +43,7 @@
)
def reparse_dag_file(
file_token: str,
session: Annotated[Session, Depends(get_session)],
session: SessionDep,
request: Request,
) -> None:
"""Request re-parsing a DAG file."""
Expand Down
Loading

0 comments on commit b882246

Please sign in to comment.