From 3ca5b06d64a3ab5c15861935c171fc902fa1186c Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Tue, 17 Sep 2024 13:51:51 -0400 Subject: [PATCH] BulkInsterters insert immediately in tests --- src/phoenix/server/app.py | 4 +- tests/conftest.py | 75 ++++++++++++++++++++++-------- tests/datasets/test_experiments.py | 15 +++--- 3 files changed, 67 insertions(+), 27 deletions(-) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 9931e32de1..cbdfd62f65 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -630,7 +630,9 @@ def create_app( scaffolder_config: Optional[ScaffolderConfig] = None, email_sender: Optional[EmailSender] = None, oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None, + bulk_inserter_factory: Optional[Callable[..., BulkInserter]] = None, ) -> FastAPI: + bulk_inserter_factory = bulk_inserter_factory or BulkInserter startup_callbacks_list: List[_Callback] = list(startup_callbacks) shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks) startup_callbacks_list.append(Facilitator(db=db)) @@ -663,7 +665,7 @@ def create_app( cache_for_dataloaders=cache_for_dataloaders, last_updated_at=last_updated_at, ) - bulk_inserter = BulkInserter( + bulk_inserter = bulk_inserter_factory( db, enable_prometheus=enable_prometheus, event_queue=dml_event_handler, diff --git a/tests/conftest.py b/tests/conftest.py index a934a03dd2..1936b7e319 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,17 +30,20 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from starlette.types import ASGIApp +import phoenix.trace.v1 as pb from phoenix.config import EXPORT_DIR from phoenix.core.model_schema_adapter import create_model_from_inferences from phoenix.db import models from phoenix.db.bulk_inserter import BulkInserter from phoenix.db.engines import aio_postgresql_engine, aio_sqlite_engine +from phoenix.db.insertion.helpers import DataManipulation from phoenix.inferences.inferences import EMPTY_INFERENCES from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import _db, create_app from phoenix.server.grpc_server import GrpcServer from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client +from phoenix.trace.schemas import Span def pytest_addoption(parser: Parser) -> None: @@ -185,7 +188,6 @@ async def app( ) -> AsyncIterator[ASGIApp]: async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(patch_batched_caller()) - await stack.enter_async_context(patch_bulk_inserter()) await stack.enter_async_context(patch_grpc_server()) app = create_app( db=db, @@ -194,6 +196,7 @@ async def app( export_path=EXPORT_DIR, umap_params=get_umap_parameters(None), serve_ui=False, + bulk_inserter_factory=TestBulkInserter, ) manager = await stack.enter_async_context(LifespanManager(app)) yield manager.app @@ -203,31 +206,38 @@ async def app( def httpx_clients( app: ASGIApp, ) -> Tuple[httpx.Client, httpx.AsyncClient]: - class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport): - def __init__(self, transport: httpx.ASGITransport) -> None: + class Transport(httpx.BaseTransport): + def __init__(self, app, asgi_transport): import nest_asyncio nest_asyncio.apply() - self.transport = transport + self.app = app + self.asgi_transport = asgi_transport def handle_request(self, request: Request) -> Response: - return asyncio.run(self.handle_async_request(request)) + response = asyncio.run(self.asgi_transport.handle_async_request(request)) - async def handle_async_request(self, request: Request) -> Response: - response = await self.transport.handle_async_request(request) + async def read_stream(): + content = b"" + async for chunk in response.stream: + content += chunk + return content + + content = asyncio.run(read_stream()) return Response( status_code=response.status_code, headers=response.headers, - content=b"".join([_ async for _ in response.stream]), + content=content, request=request, ) - transport = Transport(httpx.ASGITransport(app)) + asgi_transport = httpx.ASGITransport(app=app) + transport = Transport(httpx.ASGITransport(app), asgi_transport=asgi_transport) base_url = "http://test" return ( httpx.Client(transport=transport, base_url=base_url), - httpx.AsyncClient(transport=transport, base_url=base_url), + httpx.AsyncClient(transport=asgi_transport, base_url=base_url), ) @@ -266,15 +276,42 @@ async def patch_grpc_server() -> AsyncIterator[None]: setattr(cls, name, original) -@contextlib.asynccontextmanager -async def patch_bulk_inserter() -> AsyncIterator[None]: - cls = BulkInserter - original = cls.__init__ - name = original.__name__ - changes = {"sleep": 0.001, "retry_delay_sec": 0.001, "retry_allowance": 1000} - setattr(cls, name, lambda *_, **__: original(*_, **{**__, **changes})) - yield - setattr(cls, name, original) +class TestBulkInserter(BulkInserter): + async def __aenter__( + self, + ) -> Tuple[ + Callable[[Any], Awaitable[None]], + Callable[[Span, str], Awaitable[None]], + Callable[[pb.Evaluation], Awaitable[None]], + Callable[[DataManipulation], None], + ]: + # Return the overridden methods + return ( + self._enqueue_immediate, + self._queue_span_immediate, + self._queue_evaluation_immediate, + self._enqueue_operation_immediate, + ) + + async def __aexit__(self, *args: Any) -> None: + # No background tasks to cancel + pass + + async def _enqueue_immediate(self, *items: Any) -> None: + # Process items immediately + await self._queue_inserters.enqueue(*items) + async for event in self._queue_inserters.insert(): + self._event_queue.put(event) + + async def _enqueue_operation_immediate(self, operation: DataManipulation) -> None: + async with self._db() as session: + await operation(session) + + async def _queue_span_immediate(self, span: Span, project_name: str) -> None: + await self._insert_spans([(span, project_name)]) + + async def _queue_evaluation_immediate(self, evaluation: pb.Evaluation) -> None: + await self._insert_evaluations([evaluation]) @contextlib.asynccontextmanager diff --git a/tests/datasets/test_experiments.py b/tests/datasets/test_experiments.py index 795c28d902..55074f4a78 100644 --- a/tests/datasets/test_experiments.py +++ b/tests/datasets/test_experiments.py @@ -5,6 +5,7 @@ from unittest.mock import patch import httpx +import pytest from sqlalchemy import select from strawberry.relay import GlobalID @@ -36,8 +37,8 @@ async def test_run_experiment( simple_dataset: Any, dialect: str, ) -> None: - # if dialect == "postgresql": - # pytest.xfail("TODO: Convert this to an integration test") + if dialect == "postgresql": + pytest.xfail("TODO: Convert this to an integration test") async with db() as session: nonexistent_experiment = (await session.execute(select(models.Experiment))).scalar() @@ -102,7 +103,7 @@ def experiment_task(_) -> Dict[str, str]: # Wait until all evaluations are complete async def wait_for_evaluations(): - timeout = 30 + timeout = 15 interval = 0.5 total_wait = 0 while total_wait < timeout: @@ -180,8 +181,8 @@ async def test_run_experiment_with_llm_eval( simple_dataset: Any, dialect: str, ) -> None: - # if dialect == "postgresql": - # pytest.xfail("This test fails on PostgreSQL") + if dialect == "postgresql": + pytest.xfail("This test fails on PostgreSQL") async with db() as session: nonexistent_experiment = (await session.execute(select(models.Experiment))).scalar() @@ -292,8 +293,8 @@ async def test_run_evaluation( simple_dataset_with_one_experiment_run: Any, dialect: str, ) -> None: - # if dialect == "postgresql": - # pytest.xfail("This test fails on PostgreSQL") + if dialect == "postgresql": + pytest.xfail("This test fails on PostgreSQL") experiment = Experiment( id=str(GlobalID("Experiment", "0")),