Skip to content

Commit

Permalink
Merge pull request #22 from h0rn3t/task-local-sessions
Browse files Browse the repository at this point in the history
Task local sessions
  • Loading branch information
h0rn3t authored Oct 17, 2024
2 parents 52272ca + 83d7bef commit 083197a
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 22 deletions.
23 changes: 21 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ app.add_middleware(
routes.py

```python
import asyncio

from fastapi import APIRouter
from sqlalchemy import column
from sqlalchemy import table
from sqlalchemy import column, table, text

from databases import first_db, second_db

Expand All @@ -147,4 +148,22 @@ async def get_files_from_first_db():
async def get_files_from_second_db():
result = await second_db.session.execute(foo.select())
return result.fetchall()


@router.get("/concurrent-queries")
async def parallel_select():
async with first_db(multi_sessions=True):
async def execute_query(query):
return await first_db.session.execute(text(query))

tasks = [
asyncio.create_task(execute_query("SELECT 1")),
asyncio.create_task(execute_query("SELECT 2")),
asyncio.create_task(execute_query("SELECT 3")),
asyncio.create_task(execute_query("SELECT 4")),
asyncio.create_task(execute_query("SELECT 5")),
asyncio.create_task(execute_query("SELECT 6")),
]

await asyncio.gather(*tasks)
```
2 changes: 1 addition & 1 deletion fastapi_async_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["db", "SQLAlchemyMiddleware"]

__version__ = "0.6.1"
__version__ = "0.7.0.dev1"
91 changes: 77 additions & 14 deletions fastapi_async_sqlalchemy/middleware.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
import asyncio
from asyncio import Task
from contextvars import ContextVar
from typing import Dict, Optional, Union

from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp

from fastapi_async_sqlalchemy.exceptions import MissingSessionError, SessionNotInitialisedError

try:
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker # noqa: F811
except ImportError:
from sqlalchemy.orm import sessionmaker as async_sessionmaker


def create_middleware_and_session_proxy():
_Session: Optional[async_sessionmaker] = None
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
# Usage of context vars inside closures is not recommended, since they are not properly
# garbage collected, but in our use case context var is created on program startup and
# is used throughout the whole its lifecycle.
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)

class SQLAlchemyMiddleware(BaseHTTPMiddleware):
def __init__(
Expand Down Expand Up @@ -61,38 +65,97 @@ def session(self) -> AsyncSession:
if _Session is None:
raise SessionNotInitialisedError

session = _session.get()
if session is None:
raise MissingSessionError

return session
multi_sessions = _multi_sessions_ctx.get()
if multi_sessions:
"""In this case, we need to create a new session for each task.
We also need to commit the session on exit if commit_on_exit is True.
This is useful when we need to run multiple queries in parallel.
For example, when we need to run multiple queries in parallel in a route handler.
Example:
```python
async with db(multi_sessions=True):
async def execute_query(query):
return await db.session.execute(text(query))
tasks = [
asyncio.create_task(execute_query("SELECT 1")),
asyncio.create_task(execute_query("SELECT 2")),
asyncio.create_task(execute_query("SELECT 3")),
asyncio.create_task(execute_query("SELECT 4")),
asyncio.create_task(execute_query("SELECT 5")),
asyncio.create_task(execute_query("SELECT 6")),
]
await asyncio.gather(*tasks)
```
"""
commit_on_exit = _commit_on_exit_ctx.get()
task: Task = asyncio.current_task() # type: ignore
if not hasattr(task, "_db_session"):
task._db_session = _Session() # type: ignore

def cleanup(future):
session = getattr(task, "_db_session", None)
if session:

async def do_cleanup():
try:
if future.exception():
await session.rollback()
else:
if commit_on_exit:
await session.commit()
finally:
await session.close()

asyncio.create_task(do_cleanup())

task.add_done_callback(cleanup)
return task._db_session # type: ignore
else:
session = _session.get()
if session is None:
raise MissingSessionError
return session

class DBSession(metaclass=DBSessionMeta):
def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
def __init__(
self,
session_args: Dict = None,
commit_on_exit: bool = False,
multi_sessions: bool = False,
):
self.token = None
self.multi_sessions_token = None
self.commit_on_exit_token = None
self.session_args = session_args or {}
self.commit_on_exit = commit_on_exit
self.multi_sessions = multi_sessions

async def __aenter__(self):
if not isinstance(_Session, async_sessionmaker):
raise SessionNotInitialisedError

self.token = _session.set(_Session(**self.session_args)) # type: ignore
if self.multi_sessions:
self.multi_sessions_token = _multi_sessions_ctx.set(True)
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)

self.token = _session.set(_Session(**self.session_args))
return type(self)

async def __aexit__(self, exc_type, exc_value, traceback):
session = _session.get()

try:
if exc_type is not None:
await session.rollback()
elif (
self.commit_on_exit
): # Note: Changed this to elif to avoid commit after rollback
elif self.commit_on_exit:
await session.commit()
finally:
await session.close()
_session.reset(self.token)
if self.multi_sessions_token is not None:
_multi_sessions_ctx.reset(self.multi_sessions_token)
_commit_on_exit_ctx.reset(self.commit_on_exit_token)

return SQLAlchemyMiddleware, DBSession

Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ coverage>=5.2.1
entrypoints==0.3
fastapi==0.90.0 # pyup: ignore
flake8==3.7.9
idna==2.8
idna==3.7
importlib-metadata==1.5.0
isort==4.3.21
mccabe==0.6.1
Expand All @@ -17,13 +17,13 @@ packaging>=22.0
pathspec>=0.9.0
pluggy==0.13.0
pycodestyle==2.5.0
pydantic==1.10.13
pydantic==1.10.18
pyflakes==2.1.1
pyparsing==2.4.2
pytest==7.2.0
pytest-cov==2.11.1
PyYAML>=5.4
regex==2020.2.20
regex>=2020.2.20
requests>=2.22.0
httpx>=0.20.0
six==1.12.0
Expand All @@ -36,7 +36,7 @@ toml>=0.10.1
typed-ast>=1.4.2
urllib3>=1.25.9
wcwidth==0.1.7
zipp==3.1.0
zipp==3.19.1
black==24.4.2
pytest-asyncio==0.21.0
greenlet==2.0.2
greenlet==3.1.1
25 changes: 25 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import asyncio

import pytest
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from starlette.middleware.base import BaseHTTPMiddleware
Expand Down Expand Up @@ -148,3 +151,25 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_
session_args = {"expire_on_commit": False}
async with db(session_args=session_args):
db.session


@pytest.mark.asyncio
async def test_multi_sessions(app, db, SQLAlchemyMiddleware):
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

async with db(multi_sessions=True):

async def execute_query(query):
return await db.session.execute(text(query))

tasks = [
asyncio.create_task(execute_query("SELECT 1")),
asyncio.create_task(execute_query("SELECT 2")),
asyncio.create_task(execute_query("SELECT 3")),
asyncio.create_task(execute_query("SELECT 4")),
asyncio.create_task(execute_query("SELECT 5")),
asyncio.create_task(execute_query("SELECT 6")),
]

res = await asyncio.gather(*tasks)
assert len(res) == 6

0 comments on commit 083197a

Please sign in to comment.