Skip to content

Commit

Permalink
fixes mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Eugene Shershen committed Oct 17, 2024
1 parent 5c9f01f commit aafe26b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 27 additions & 0 deletions fastapi_async_sqlalchemy/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def create_middleware_and_session_proxy():
_session: ContextVar[Optional[AsyncSession]] = ContextVar("_session", default=None)
_multi_sessions_ctx: ContextVar[bool] = ContextVar("_multi_sessions_context", default=False)
_commit_on_exit_ctx: ContextVar[bool] = ContextVar("_commit_on_exit_ctx", default=False)
# Usage of context vars inside closures is not recommended, since they are not properly
# garbage collected, but in our use case context var is created on program startup and
# is used throughout the whole its lifecycle.

class SQLAlchemyMiddleware(BaseHTTPMiddleware):
def __init__(
Expand Down Expand Up @@ -58,11 +61,35 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
class DBSessionMeta(type):
@property
def session(self) -> AsyncSession:
"""Return an instance of Session local to the current async context."""
if _Session is None:
raise SessionNotInitialisedError

multi_sessions = _multi_sessions_ctx.get()
if multi_sessions:
"""If multi_sessions is True, we are in a context where multiple sessions are allowed.
In this case, we need to create a new session for each task.
We also need to commit the session on exit if commit_on_exit is True.
This is useful when we need to run multiple queries in parallel.
For example, when we need to run multiple queries in parallel in a route handler.
Example:
```python
async with db(multi_sessions=True):
async def execute_query(query):
return await db.session.execute(text(query))
tasks = [
asyncio.create_task(execute_query("SELECT 1")),
asyncio.create_task(execute_query("SELECT 2")),
asyncio.create_task(execute_query("SELECT 3")),
asyncio.create_task(execute_query("SELECT 4")),
asyncio.create_task(execute_query("SELECT 5")),
asyncio.create_task(execute_query("SELECT 6")),
]
await asyncio.gather(*tasks)
```
"""
commit_on_exit = _commit_on_exit_ctx.get()
task: Task = asyncio.current_task() # type: ignore
if not hasattr(task, "_db_session"):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ coverage>=5.2.1
entrypoints==0.3
fastapi==0.90.0 # pyup: ignore
flake8==3.7.9
idna==2.8
idna==3.7
importlib-metadata==1.5.0
isort==4.3.21
mccabe==0.6.1
Expand Down Expand Up @@ -36,7 +36,7 @@ toml>=0.10.1
typed-ast>=1.4.2
urllib3>=1.25.9
wcwidth==0.1.7
zipp==3.1.0
zipp==3.19.1
black==24.4.2
pytest-asyncio==0.21.0
greenlet==3.1.1

0 comments on commit aafe26b

Please sign in to comment.