Skip to content

Commit

Permalink
WIP: multi_sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugene Shershen committed Oct 18, 2024
1 parent 83d7bef commit 0c74aaf
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 38 deletions.
2 changes: 1 addition & 1 deletion fastapi_async_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["db", "SQLAlchemyMiddleware"]

__version__ = "0.7.0.dev1"
__version__ = "0.7.0.dev2"
75 changes: 38 additions & 37 deletions fastapi_async_sqlalchemy/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from asyncio import Task
from contextvars import ContextVar
from typing import Dict, Optional, Union

Expand All @@ -22,6 +21,9 @@ def create_middleware_and_session_proxy():
_Session: Optional[async_sessionmaker] = None
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
_task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar(
"_task_session_ctx", default=None
)
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
# Usage of context vars inside closures is not recommended, since they are not properly
# garbage collected, but in our use case context var is created on program startup and
Expand Down Expand Up @@ -90,28 +92,26 @@ async def execute_query(query):
```
"""
commit_on_exit = _commit_on_exit_ctx.get()
task: Task = asyncio.current_task() # type: ignore
if not hasattr(task, "_db_session"):
task._db_session = _Session() # type: ignore

def cleanup(future):
session = getattr(task, "_db_session", None)
if session:

async def do_cleanup():
try:
if future.exception():
await session.rollback()
else:
if commit_on_exit:
await session.commit()
finally:
await session.close()

asyncio.create_task(do_cleanup())

task.add_done_callback(cleanup)
return task._db_session # type: ignore
session = _task_session_ctx.get()
if session is None:
session = _Session()
_task_session_ctx.set(session)

async def cleanup():
try:
if commit_on_exit:
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
_task_session_ctx.set(None)

task = asyncio.current_task()
if task is not None:
task.add_done_callback(lambda t: asyncio.create_task(cleanup()))
return session
else:
session = _session.get()
if session is None:
Expand Down Expand Up @@ -139,23 +139,24 @@ async def __aenter__(self):
if self.multi_sessions:
self.multi_sessions_token = _multi_sessions_ctx.set(True)
self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit)

self.token = _session.set(_Session(**self.session_args))
else:
self.token = _session.set(_Session(**self.session_args))
return type(self)

async def __aexit__(self, exc_type, exc_value, traceback):
session = _session.get()
try:
if exc_type is not None:
await session.rollback()
elif self.commit_on_exit:
await session.commit()
finally:
await session.close()
_session.reset(self.token)
if self.multi_sessions_token is not None:
_multi_sessions_ctx.reset(self.multi_sessions_token)
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
if self.multi_sessions:
_multi_sessions_ctx.reset(self.multi_sessions_token)
_commit_on_exit_ctx.reset(self.commit_on_exit_token)
else:
session = _session.get()
try:
if exc_type is not None:
await session.rollback()
elif self.commit_on_exit:
await session.commit()
finally:
await session.close()
_session.reset(self.token)

return SQLAlchemyMiddleware, DBSession

Expand Down
25 changes: 25 additions & 0 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,28 @@ async def execute_query(query):

res = await asyncio.gather(*tasks)
assert len(res) == 6


@pytest.mark.asyncio
async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware):
app.add_middleware(SQLAlchemyMiddleware, db_url=db_url)

async with db(multi_sessions=True, commit_on_exit=True):
await db.session.execute(
text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)")
)

async def insert_data(value):
await db.session.execute(
text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value}
)
await db.session.flush()

tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)]

result_ids = await asyncio.gather(*tasks)
assert len(result_ids) == 10

records = await db.session.execute(text("SELECT * FROM my_model"))
records = records.scalars().all()
assert len(records) == 10

0 comments on commit 0c74aaf

Please sign in to comment.