-
Notifications
You must be signed in to change notification settings - Fork 315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Fix DB unittest reliability #4548
Changes from all commits
0e5a10c
148a6ea
06e1ff9
12b25dd
c3d785f
2397aac
9095d7a
aca3124
2d73c48
085a5f6
0618dfe
adae758
aa46b38
c16c952
ee87dad
d445cc4
66b5735
9cb9d1d
ce46a5b
9200ed9
f1e1183
f2cf7db
56bcaab
17e7c36
036a170
165bec1
7b0bcbc
3ca5b06
07b218a
a48660c
0c7d0fa
fc4da4c
b76949f
60b2ad3
87562c6
3861e3a
4fe6552
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
import asyncio | ||
import contextlib | ||
import time | ||
from asyncio import AbstractEventLoop, get_running_loop | ||
import os | ||
import tempfile | ||
from asyncio import AbstractEventLoop | ||
from functools import partial | ||
from importlib.metadata import version | ||
from random import getrandbits | ||
|
@@ -31,17 +32,20 @@ | |
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession | ||
from starlette.types import ASGIApp | ||
|
||
import phoenix.trace.v1 as pb | ||
from phoenix.config import EXPORT_DIR | ||
from phoenix.core.model_schema_adapter import create_model_from_inferences | ||
from phoenix.db import models | ||
from phoenix.db.bulk_inserter import BulkInserter | ||
from phoenix.db.engines import aio_postgresql_engine, aio_sqlite_engine | ||
from phoenix.db.insertion.helpers import DataManipulation | ||
from phoenix.inferences.inferences import EMPTY_INFERENCES | ||
from phoenix.pointcloud.umap_parameters import get_umap_parameters | ||
from phoenix.server.app import _db, create_app | ||
from phoenix.server.grpc_server import GrpcServer | ||
from phoenix.server.types import BatchedCaller, DbSessionFactory | ||
from phoenix.session.client import Client | ||
from phoenix.trace.schemas import Span | ||
|
||
|
||
def pytest_addoption(parser: Parser) -> None: | ||
|
@@ -65,20 +69,21 @@ def pytest_terminal_summary( | |
|
||
xfail_threshold = 12 # our tests are currently quite unreliable | ||
|
||
terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}") | ||
terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}") | ||
if config.getoption("--run-postgres"): | ||
terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}") | ||
terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}") | ||
|
||
if exitstatus == pytest.ExitCode.OK: | ||
if xfails < xfail_threshold: | ||
terminalreporter.write_sep( | ||
"=", "Within xfail threshold. Passing the test suite.", green=True | ||
) | ||
terminalreporter._session.exitstatus = pytest.ExitCode.OK | ||
else: | ||
terminalreporter.write_sep( | ||
"=", "Too many flaky tests. Failing the test suite.", red=True | ||
) | ||
terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED | ||
if exitstatus == pytest.ExitCode.OK: | ||
if xfails < xfail_threshold: | ||
terminalreporter.write_sep( | ||
"=", "Within xfail threshold. Passing the test suite.", green=True | ||
) | ||
terminalreporter._session.exitstatus = pytest.ExitCode.OK | ||
else: | ||
terminalreporter.write_sep( | ||
"=", "Too many flaky tests. Failing the test suite.", red=True | ||
) | ||
terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED | ||
|
||
|
||
def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None: | ||
|
@@ -89,11 +94,6 @@ def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None: | |
if "postgresql" in item.callspec.params.values(): | ||
item.add_marker(skip_postgres) | ||
|
||
if config.getoption("--allow-flaky"): | ||
for item in items: | ||
if "dialect" in item.fixturenames: | ||
item.add_marker(pytest.mark.xfail(reason="database tests are currently flaky")) | ||
|
||
|
||
@pytest.fixture | ||
def pydantic_version() -> Literal["v1", "v2"]: | ||
|
@@ -116,7 +116,7 @@ def openai_api_key(monkeypatch: pytest.MonkeyPatch) -> str: | |
postgresql_connection = factories.postgresql("postgresql_proc") | ||
|
||
|
||
@pytest.fixture() | ||
@pytest.fixture(scope="function") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't "function" the default scope? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, but it's better to be explicit in case the default changes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay. fwiw i doubt pytest will change that default since it would be a very disruptive change. |
||
async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL]: | ||
connection = postgresql_connection | ||
user = connection.info.user | ||
|
@@ -127,10 +127,11 @@ async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL | |
yield make_url(f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}") | ||
|
||
|
||
@pytest.fixture | ||
@pytest.fixture(scope="function") | ||
async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]: | ||
engine = aio_postgresql_engine(postgresql_url, migrate=False) | ||
async with engine.begin() as conn: | ||
await conn.run_sync(models.Base.metadata.drop_all) | ||
await conn.run_sync(models.Base.metadata.create_all) | ||
yield engine | ||
await engine.dispose() | ||
|
@@ -141,33 +142,36 @@ def dialect(request: SubRequest) -> str: | |
return request.param | ||
|
||
|
||
@pytest.fixture | ||
@pytest.fixture(scope="function") | ||
async def sqlite_engine() -> AsyncIterator[AsyncEngine]: | ||
engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False) | ||
async with engine.begin() as conn: | ||
await conn.run_sync(models.Base.metadata.create_all) | ||
yield engine | ||
await engine.dispose() | ||
with tempfile.TemporaryDirectory() as temp_dir: | ||
db_file = os.path.join(temp_dir, "test.db") | ||
engine = aio_sqlite_engine(make_url(f"sqlite+aiosqlite:///{db_file}"), migrate=False) | ||
async with engine.begin() as conn: | ||
await conn.run_sync(models.Base.metadata.drop_all) | ||
await conn.run_sync(models.Base.metadata.create_all) | ||
yield engine | ||
await engine.dispose() | ||
|
||
|
||
@pytest.fixture | ||
@pytest.fixture(scope="function") | ||
def db( | ||
request: SubRequest, | ||
dialect: str, | ||
) -> DbSessionFactory: | ||
if dialect == "sqlite": | ||
return _db_with_lock(request.getfixturevalue("sqlite_engine")) | ||
return db_session_factory(request.getfixturevalue("sqlite_engine")) | ||
elif dialect == "postgresql": | ||
return _db_with_lock(request.getfixturevalue("postgresql_engine")) | ||
return db_session_factory(request.getfixturevalue("postgresql_engine")) | ||
raise ValueError(f"Unknown db fixture: {dialect}") | ||
|
||
|
||
def _db_with_lock(engine: AsyncEngine) -> DbSessionFactory: | ||
lock, db = asyncio.Lock(), _db(engine) | ||
def db_session_factory(engine: AsyncEngine) -> DbSessionFactory: | ||
db = _db(engine, bypass_lock=True) | ||
|
||
@contextlib.asynccontextmanager | ||
async def factory() -> AsyncIterator[AsyncSession]: | ||
async with lock, db() as session: | ||
async with db() as session: | ||
yield session | ||
|
||
return DbSessionFactory(db=factory, dialect=engine.dialect.name) | ||
|
@@ -186,7 +190,6 @@ async def app( | |
) -> AsyncIterator[ASGIApp]: | ||
async with contextlib.AsyncExitStack() as stack: | ||
await stack.enter_async_context(patch_batched_caller()) | ||
await stack.enter_async_context(patch_bulk_inserter()) | ||
await stack.enter_async_context(patch_grpc_server()) | ||
app = create_app( | ||
db=db, | ||
|
@@ -195,57 +198,48 @@ async def app( | |
export_path=EXPORT_DIR, | ||
umap_params=get_umap_parameters(None), | ||
serve_ui=False, | ||
bulk_inserter_factory=TestBulkInserter, | ||
) | ||
manager = await stack.enter_async_context(LifespanManager(app)) | ||
yield manager.app | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def event_loop_policy(): | ||
try: | ||
import uvloop | ||
except ImportError: | ||
return asyncio.DefaultEventLoopPolicy() | ||
return uvloop.EventLoopPolicy() | ||
|
||
|
||
@pytest.fixture | ||
async def loop() -> AbstractEventLoop: | ||
return get_running_loop() | ||
|
||
|
||
@pytest.fixture | ||
def httpx_clients( | ||
app: ASGIApp, | ||
loop: AbstractEventLoop, | ||
) -> Tuple[httpx.Client, httpx.AsyncClient]: | ||
class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport): | ||
def __init__(self, transport: httpx.ASGITransport) -> None: | ||
self.transport = transport | ||
class Transport(httpx.BaseTransport): | ||
def __init__(self, app, asgi_transport): | ||
import nest_asyncio | ||
|
||
nest_asyncio.apply() | ||
|
||
self.app = app | ||
self.asgi_transport = asgi_transport | ||
|
||
def handle_request(self, request: Request) -> Response: | ||
fut = loop.create_task(self.handle_async_request(request)) | ||
time_cutoff = time.time() + 10 | ||
while not fut.done() and time.time() < time_cutoff: | ||
time.sleep(0.01) | ||
if fut.done(): | ||
return fut.result() | ||
raise TimeoutError | ||
|
||
async def handle_async_request(self, request: Request) -> Response: | ||
response = await self.transport.handle_async_request(request) | ||
response = asyncio.run(self.asgi_transport.handle_async_request(request)) | ||
|
||
async def read_stream(): | ||
content = b"" | ||
async for chunk in response.stream: | ||
content += chunk | ||
return content | ||
|
||
content = asyncio.run(read_stream()) | ||
return Response( | ||
status_code=response.status_code, | ||
headers=response.headers, | ||
content=b"".join([_ async for _ in response.stream]), | ||
content=content, | ||
request=request, | ||
) | ||
|
||
transport = Transport(httpx.ASGITransport(app)) | ||
asgi_transport = httpx.ASGITransport(app=app) | ||
transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport) | ||
base_url = "http://test" | ||
return ( | ||
httpx.Client(transport=transport, base_url=base_url), | ||
httpx.AsyncClient(transport=transport, base_url=base_url), | ||
httpx.AsyncClient(transport=asgi_transport, base_url=base_url), | ||
) | ||
|
||
|
||
|
@@ -284,15 +278,42 @@ async def patch_grpc_server() -> AsyncIterator[None]: | |
setattr(cls, name, original) | ||
|
||
|
||
@contextlib.asynccontextmanager | ||
async def patch_bulk_inserter() -> AsyncIterator[None]: | ||
cls = BulkInserter | ||
original = cls.__init__ | ||
name = original.__name__ | ||
changes = {"sleep": 0.001, "retry_delay_sec": 0.001, "retry_allowance": 1000} | ||
setattr(cls, name, lambda *_, **__: original(*_, **{**__, **changes})) | ||
yield | ||
setattr(cls, name, original) | ||
class TestBulkInserter(BulkInserter): | ||
async def __aenter__( | ||
self, | ||
) -> Tuple[ | ||
Callable[[Any], Awaitable[None]], | ||
Callable[[Span, str], Awaitable[None]], | ||
Callable[[pb.Evaluation], Awaitable[None]], | ||
Callable[[DataManipulation], None], | ||
]: | ||
# Return the overridden methods | ||
return ( | ||
self._enqueue_immediate, | ||
self._queue_span_immediate, | ||
self._queue_evaluation_immediate, | ||
self._enqueue_operation_immediate, | ||
) | ||
|
||
async def __aexit__(self, *args: Any) -> None: | ||
# No background tasks to cancel | ||
pass | ||
|
||
async def _enqueue_immediate(self, *items: Any) -> None: | ||
# Process items immediately | ||
await self._queue_inserters.enqueue(*items) | ||
async for event in self._queue_inserters.insert(): | ||
self._event_queue.put(event) | ||
|
||
async def _enqueue_operation_immediate(self, operation: DataManipulation) -> None: | ||
async with self._db() as session: | ||
await operation(session) | ||
|
||
async def _queue_span_immediate(self, span: Span, project_name: str) -> None: | ||
await self._insert_spans([(span, project_name)]) | ||
|
||
async def _queue_evaluation_immediate(self, evaluation: pb.Evaluation) -> None: | ||
await self._insert_evaluations([evaluation]) | ||
|
||
|
||
@contextlib.asynccontextmanager | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we no longer need
pytest-xdist
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how many cores are we using if we don't specify this option? do we still need
pytest-xdist
?