diff --git a/tests/test_session.py b/tests/test_session.py index 06163ac..c7e8666 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,7 @@ +import asyncio + import pytest +from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from starlette.middleware.base import BaseHTTPMiddleware @@ -148,3 +151,24 @@ async def test_db_context_session_args(app, db, SQLAlchemyMiddleware, commit_on_ session_args = {"expire_on_commit": False} async with db(session_args=session_args): db.session + + +@pytest.mark.asyncio +async def test_multi_sessions(app, db, SQLAlchemyMiddleware): + app.add_middleware(SQLAlchemyMiddleware, db_url=db_url) + + async with db(multi_sessions=True): + async def execute_query(query): + return await db.session.execute(text(query)) + + tasks = [ + asyncio.create_task(execute_query("SELECT 1")), + asyncio.create_task(execute_query("SELECT 2")), + asyncio.create_task(execute_query("SELECT 3")), + asyncio.create_task(execute_query("SELECT 4")), + asyncio.create_task(execute_query("SELECT 5")), + asyncio.create_task(execute_query("SELECT 6")), + ] + + res = await asyncio.gather(*tasks) + assert len(res) == 6