diff --git a/test/unit/webapps/test_request_scoped_sqlalchemy_sessions.py b/test/unit/webapps/test_request_scoped_sqlalchemy_sessions.py index 31d90e28a129..fbe4fd9cae6d 100644 --- a/test/unit/webapps/test_request_scoped_sqlalchemy_sessions.py +++ b/test/unit/webapps/test_request_scoped_sqlalchemy_sessions.py @@ -8,7 +8,10 @@ import pytest from fastapi import FastAPI from fastapi.param_functions import Depends -from httpx import AsyncClient +from httpx import ( + ASGITransport, + AsyncClient, +) from starlette_context import context as request_context from galaxy.app_unittest_utils.galaxy_mock import MockApp @@ -16,6 +19,7 @@ app = FastAPI() add_request_id_middleware(app) +transport = ASGITransport(app=app) GX_APP = None @@ -96,7 +100,7 @@ def assert_scoped_session_is_thread_local(gx_app): @pytest.mark.asyncio async def test_request_scoped_sa_session_single_request(): - async with AsyncClient(app=app, base_url="http://test") as client: + async with AsyncClient(base_url="http://test", transport=transport) as client: response = await client.get("/") assert response.status_code == 200 assert response.json() == {"msg": "Hello World"} @@ -106,7 +110,7 @@ async def test_request_scoped_sa_session_single_request(): @pytest.mark.asyncio async def test_request_scoped_sa_session_exception(): - async with AsyncClient(app=app, base_url="http://test") as client: + async with AsyncClient(base_url="http://test", transport=transport) as client: with pytest.raises(UnexpectedException): await client.get("/internal_server_error") assert GX_APP @@ -115,7 +119,7 @@ async def test_request_scoped_sa_session_exception(): @pytest.mark.asyncio async def test_request_scoped_sa_session_concurrent_requests_sync(): - async with AsyncClient(app=app, base_url="http://test") as client: + async with AsyncClient(base_url="http://test", transport=transport) as client: awaitables = (client.get("/sync_wait") for _ in range(10)) result = await asyncio.gather(*awaitables) uuids = [] @@ -129,7 +133,7 @@ async def test_request_scoped_sa_session_concurrent_requests_sync(): @pytest.mark.asyncio async def test_request_scoped_sa_session_concurrent_requests_async(): - async with AsyncClient(app=app, base_url="http://test") as client: + async with AsyncClient(base_url="http://test", transport=transport) as client: awaitables = (client.get("/async_wait") for _ in range(10)) result = await asyncio.gather(*awaitables) uuids = [] @@ -147,7 +151,7 @@ async def test_request_scoped_sa_session_concurrent_requests_and_background_thre target = functools.partial(assert_scoped_session_is_thread_local, GX_APP) with concurrent.futures.ThreadPoolExecutor() as pool: background_pool = loop.run_in_executor(pool, target) - async with AsyncClient(app=app, base_url="http://test") as client: + async with AsyncClient(base_url="http://test", transport=transport) as client: awaitables = (client.get("/async_wait") for _ in range(10)) result = await asyncio.gather(*awaitables) uuids = []