diff --git a/fastapi_async_sqlalchemy/__init__.py b/fastapi_async_sqlalchemy/__init__.py index 21c6edc..2821653 100644 --- a/fastapi_async_sqlalchemy/__init__.py +++ b/fastapi_async_sqlalchemy/__init__.py @@ -2,4 +2,4 @@ __all__ = ["db", "SQLAlchemyMiddleware"] -__version__ = "0.7.0.dev1" +__version__ = "0.7.0.dev2" diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index e3b1563..13cb359 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -1,5 +1,4 @@ import asyncio -from asyncio import Task from contextvars import ContextVar from typing import Dict, Optional, Union @@ -22,6 +21,9 @@ def create_middleware_and_session_proxy(): _Session: Optional[async_sessionmaker] = None _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) + _task_session_ctx: ContextVar[Optional[AsyncSession]] = ContextVar( + "_task_session_ctx", default=None + ) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) # Usage of context vars inside closures is not recommended, since they are not properly # garbage collected, but in our use case context var is created on program startup and @@ -90,28 +92,26 @@ async def execute_query(query): ``` """ commit_on_exit = _commit_on_exit_ctx.get() - task: Task = asyncio.current_task() # type: ignore - if not hasattr(task, "_db_session"): - task._db_session = _Session() # type: ignore - - def cleanup(future): - session = getattr(task, "_db_session", None) - if session: - - async def do_cleanup(): - try: - if future.exception(): - await session.rollback() - else: - if commit_on_exit: - await session.commit() - finally: - await session.close() - - asyncio.create_task(do_cleanup()) - - task.add_done_callback(cleanup) - return task._db_session # type: ignore + session = _task_session_ctx.get() + if session is None: + session = _Session() + _task_session_ctx.set(session) + + async def cleanup(): + try: + if commit_on_exit: + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + _task_session_ctx.set(None) + + task = asyncio.current_task() + if task is not None: + task.add_done_callback(lambda t: asyncio.create_task(cleanup())) + return session else: session = _session.get() if session is None: @@ -139,23 +139,24 @@ async def __aenter__(self): if self.multi_sessions: self.multi_sessions_token = _multi_sessions_ctx.set(True) self.commit_on_exit_token = _commit_on_exit_ctx.set(self.commit_on_exit) - - self.token = _session.set(_Session(**self.session_args)) + else: + self.token = _session.set(_Session(**self.session_args)) return type(self) async def __aexit__(self, exc_type, exc_value, traceback): - session = _session.get() - try: - if exc_type is not None: - await session.rollback() - elif self.commit_on_exit: - await session.commit() - finally: - await session.close() - _session.reset(self.token) - if self.multi_sessions_token is not None: - _multi_sessions_ctx.reset(self.multi_sessions_token) - _commit_on_exit_ctx.reset(self.commit_on_exit_token) + if self.multi_sessions: + _multi_sessions_ctx.reset(self.multi_sessions_token) + _commit_on_exit_ctx.reset(self.commit_on_exit_token) + else: + session = _session.get() + try: + if exc_type is not None: + await session.rollback() + elif self.commit_on_exit: + await session.commit() + finally: + await session.close() + _session.reset(self.token) return SQLAlchemyMiddleware, DBSession diff --git a/tests/test_session.py b/tests/test_session.py index 82f5dc9..9400fea 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -173,3 +173,28 @@ async def execute_query(query): res = await asyncio.gather(*tasks) assert len(res) == 6 + + +@pytest.mark.asyncio +async def test_concurrent_inserts(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + async with db(multi_sessions=True, commit_on_exit=True): + await db.session.execute( + text("CREATE TABLE IF NOT EXISTS my_model (id INTEGER PRIMARY KEY, value TEXT)") + ) + + async def insert_data(value): + await db.session.execute( + text("INSERT INTO my_model (value) VALUES (:value)"), {"value": value} + ) + await db.session.flush() + + tasks = [asyncio.create_task(insert_data(f"value_{i}")) for i in range(10)] + + result_ids = await asyncio.gather(*tasks) + assert len(result_ids) == 10 + + records = await db.session.execute(text("SELECT * FROM my_model")) + records = records.scalars().all() + assert len(records) == 10