From aafe26b517e4a3e53428574b182ac8f27df276e6 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Thu, 17 Oct 2024 22:39:11 +0300 Subject: [PATCH] fixes mypy --- fastapi_async_sqlalchemy/middleware.py | 27 ++++++++++++++++++++++++++ requirements.txt | 4 ++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/fastapi_async_sqlalchemy/middleware.py b/fastapi_async_sqlalchemy/middleware.py index 6893e28..f70a49e 100644 --- a/fastapi_async_sqlalchemy/middleware.py +++ b/fastapi_async_sqlalchemy/middleware.py @@ -23,6 +23,9 @@ def create_middleware_and_session_proxy(): _session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None) _multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False) _commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False) + # Usage of context vars inside closures is not recommended, since they are not properly + # garbage collected, but in our use case context var is created on program startup and + # is used throughout the whole its lifecycle. class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( @@ -58,11 +61,35 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): class DBSessionMeta(type): @property def session(self) -> AsyncSession: + """Return an instance of Session local to the current async context.""" if _Session is None: raise SessionNotInitialisedError multi_sessions = _multi_sessions_ctx.get() if multi_sessions: + """If multi_sessions is True, we are in a context where multiple sessions are allowed. + In this case, we need to create a new session for each task. + We also need to commit the session on exit if commit_on_exit is True. + This is useful when we need to run multiple queries in parallel. + For example, when we need to run multiple queries in parallel in a route handler. + Example: + ```python + async with db(multi_sessions=True): + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + await asyncio.gather(*tasks) + ``` + """ commit_on_exit = _commit_on_exit_ctx.get() task: Task = asyncio.current_task() # type: ignore if not hasattr(task, "_db_session"): diff --git a/requirements.txt b/requirements.txt index d3232d8..e3a0644 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ coverage>=5.2.1 entrypoints==0.3 fastapi==0.90.0 # pyup: ignore flake8==3.7.9 -idna==2.8 +idna==3.7 importlib-metadata==1.5.0 isort==4.3.21 mccabe==0.6.1 @@ -36,7 +36,7 @@ toml>=0.10.1 typed-ast>=1.4.2 urllib3>=1.25.9 wcwidth==0.1.7 -zipp==3.1.0 +zipp==3.19.1 black==24.4.2 pytest-asyncio==0.21.0 greenlet==3.1.1