diff --git a/pyproject.toml b/pyproject.toml index 29dd4b9c3d9..acb98f5bbf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,7 +248,7 @@ dependencies = [ ] [tool.hatch.envs.default.scripts] -tests = "pytest -n auto {args}" +tests = "pytest {args}" coverage = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/phoenix --cov=tests {args}" [[tool.hatch.envs.test.matrix]] @@ -262,7 +262,6 @@ addopts = [ "--doctest-modules", "--new-first", "--showlocals", - "--exitfirst", ] testpaths = [ "tests", diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 9a71aa1c329..63f848a3c7c 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -393,8 +393,8 @@ async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]: for callback in startup_callbacks: if isinstance((res := callback()), Awaitable): await res - global DB_MUTEX - DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None + # global DB_MUTEX + # DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None async with AsyncExitStack() as stack: ( enqueue, diff --git a/tests/conftest.py b/tests/conftest.py index 3bc6144de2a..a8e786b591c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ import httpx import pytest +import sqlean from _pytest.config import Config, Parser from _pytest.fixtures import SubRequest from _pytest.terminal import TerminalReporter @@ -27,7 +28,8 @@ from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import make_url -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker from starlette.types import ASGIApp from phoenix.config import EXPORT_DIR @@ -64,20 +66,21 @@ def pytest_terminal_summary( xfail_threshold = 12 # our tests are currently quite unreliable - terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}") - terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}") + if config.getoption("--run-postgres"): + terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}") + terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}") - if exitstatus == pytest.ExitCode.OK: - if xfails < xfail_threshold: - terminalreporter.write_sep( - "=", "Within xfail threshold. Passing the test suite.", green=True - ) - terminalreporter._session.exitstatus = pytest.ExitCode.OK - else: - terminalreporter.write_sep( - "=", "Too many flaky tests. Failing the test suite.", red=True - ) - terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED + if exitstatus == pytest.ExitCode.OK: + if xfails < xfail_threshold: + terminalreporter.write_sep( + "=", "Within xfail threshold. Passing the test suite.", green=True + ) + terminalreporter._session.exitstatus = pytest.ExitCode.OK + else: + terminalreporter.write_sep( + "=", "Too many flaky tests. Failing the test suite.", red=True + ) + terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None: @@ -126,7 +129,7 @@ async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL yield make_url(f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}") -@pytest.fixture +@pytest.fixture() async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]: engine = aio_postgresql_engine(postgresql_url, migrate=False) async with engine.begin() as conn: @@ -140,16 +143,21 @@ def dialect(request: SubRequest) -> str: return request.param +def create_async_sqlite_engine() -> sessionmaker: + return create_async_engine("sqlite+aiosqlite:///:memory:", module=sqlean) + + @pytest.fixture async def sqlite_engine() -> AsyncIterator[AsyncEngine]: engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False) + # engine = create_async_sqlite_engine() async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) yield engine await engine.dispose() -@pytest.fixture +@pytest.fixture(scope="function") def db( request: SubRequest, dialect: str, @@ -162,12 +170,14 @@ def db( # def _db_with_lock(engine: AsyncEngine) -> DbSessionFactory: -# lock, db = asyncio.Lock(), _db(engine) +# lock = threading.Lock() +# db = _db(engine) # @contextlib.asynccontextmanager # async def factory() -> AsyncIterator[AsyncSession]: -# async with lock, db() as session: -# yield session +# with lock: +# async with db() as session: +# yield session # return DbSessionFactory(db=factory, dialect=engine.dialect.name) @@ -227,7 +237,6 @@ async def loop() -> AbstractEventLoop: @pytest.fixture def httpx_clients( app: ASGIApp, - loop: AbstractEventLoop, ) -> Tuple[httpx.Client, httpx.AsyncClient]: class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport): def __init__(self, transport: httpx.ASGITransport) -> None: diff --git a/tests/datasets/test_experiments.py b/tests/datasets/test_experiments.py index f2e02063cf6..d036ca3b174 100644 --- a/tests/datasets/test_experiments.py +++ b/tests/datasets/test_experiments.py @@ -89,7 +89,6 @@ def experiment_task(_) -> Dict[str, str]: evaluators={f"{i:02}": e for i, e in enumerate(evaluators)}, print_summary=False, ) - time.sleep(1) # Wait for the entire experiment to be run experiment_id = from_global_id_with_expected_type( GlobalID.from_id(experiment.id), "Experiment" ) @@ -262,6 +261,7 @@ async def test_run_evaluation( ) with patch("phoenix.experiments.functions._phoenix_clients", return_value=httpx_clients): evaluate_experiment(experiment, evaluators=[lambda _: _]) + time.sleep(1) # Wait for the evaluations to be inserted async with db() as session: evaluations = list(await session.scalars(select(models.ExperimentRunAnnotation))) assert len(evaluations) == 1 @@ -378,6 +378,6 @@ def can_i_evaluate_with_everything_in_any_order( async def test_get_experiment_client_method(px_client, simple_dataset_with_one_experiment_run): experiment_gid = GlobalID("Experiment", "0") - experiment = px_client.get_experiment_runs(experiment_id=experiment_gid) + experiment = px_client.get_experiment(experiment_id=experiment_gid) assert experiment assert isinstance(experiment, Experiment) diff --git a/tests/server/api/routers/v1/test_spans.py b/tests/server/api/routers/v1/test_spans.py index 83032f92100..8a6e363ede7 100644 --- a/tests/server/api/routers/v1/test_spans.py +++ b/tests/server/api/routers/v1/test_spans.py @@ -31,7 +31,7 @@ async def test_span_round_tripping_with_docs( orig_count = len(orig_docs) assert orig_count px_client.log_traces(TraceDataset(df)) - time.sleep(0.1) # Wait for the spans to be inserted + time.sleep(1) # Wait for the spans to be inserted docs = cast(pd.DataFrame, px_client.query_spans(doc_query)) new_count = len(docs) assert new_count