Skip to content

Commit

Permalink
Experiment with locks
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Sep 10, 2024
1 parent 02afda9 commit b3cfce7
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 27 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ dependencies = [
]

[tool.hatch.envs.default.scripts]
tests = "pytest -n auto {args}"
tests = "pytest {args}"
coverage = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=src/phoenix --cov=tests {args}"

[[tool.hatch.envs.test.matrix]]
Expand All @@ -262,7 +262,6 @@ addopts = [
"--doctest-modules",
"--new-first",
"--showlocals",
"--exitfirst",
]
testpaths = [
"tests",
Expand Down
4 changes: 2 additions & 2 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]:
for callback in startup_callbacks:
if isinstance((res := callback()), Awaitable):
await res
global DB_MUTEX
DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
# global DB_MUTEX
# DB_MUTEX = asyncio.Lock() if db.dialect is SupportedSQLDialect.SQLITE else None
async with AsyncExitStack() as stack:
(
enqueue,
Expand Down
49 changes: 29 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import httpx
import pytest
import sqlean
from _pytest.config import Config, Parser
from _pytest.fixtures import SubRequest
from _pytest.terminal import TerminalReporter
Expand All @@ -27,7 +28,8 @@
from psycopg import Connection
from pytest_postgresql import factories
from sqlalchemy import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from starlette.types import ASGIApp

from phoenix.config import EXPORT_DIR
Expand Down Expand Up @@ -64,20 +66,21 @@ def pytest_terminal_summary(

xfail_threshold = 12 # our tests are currently quite unreliable

terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}")
terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}")
if config.getoption("--run-postgres"):
terminalreporter.write_sep("=", f"xfail threshold: {xfail_threshold}")
terminalreporter.write_sep("=", f"xpasses: {xpasses}, xfails: {xfails}")

if exitstatus == pytest.ExitCode.OK:
if xfails < xfail_threshold:
terminalreporter.write_sep(
"=", "Within xfail threshold. Passing the test suite.", green=True
)
terminalreporter._session.exitstatus = pytest.ExitCode.OK
else:
terminalreporter.write_sep(
"=", "Too many flaky tests. Failing the test suite.", red=True
)
terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED
if exitstatus == pytest.ExitCode.OK:
if xfails < xfail_threshold:
terminalreporter.write_sep(
"=", "Within xfail threshold. Passing the test suite.", green=True
)
terminalreporter._session.exitstatus = pytest.ExitCode.OK
else:
terminalreporter.write_sep(
"=", "Too many flaky tests. Failing the test suite.", red=True
)
terminalreporter._session.exitstatus = pytest.ExitCode.TESTS_FAILED


def pytest_collection_modifyitems(config: Config, items: List[Any]) -> None:
Expand Down Expand Up @@ -126,7 +129,7 @@ async def postgresql_url(postgresql_connection: Connection) -> AsyncIterator[URL
yield make_url(f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}")


@pytest.fixture
@pytest.fixture()
async def postgresql_engine(postgresql_url: URL) -> AsyncIterator[AsyncEngine]:
engine = aio_postgresql_engine(postgresql_url, migrate=False)
async with engine.begin() as conn:
Expand All @@ -140,16 +143,21 @@ def dialect(request: SubRequest) -> str:
return request.param


def create_async_sqlite_engine() -> sessionmaker:
return create_async_engine("sqlite+aiosqlite:///:memory:", module=sqlean)


@pytest.fixture
async def sqlite_engine() -> AsyncIterator[AsyncEngine]:
engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False)
# engine = create_async_sqlite_engine()
async with engine.begin() as conn:
await conn.run_sync(models.Base.metadata.create_all)
yield engine
await engine.dispose()


@pytest.fixture
@pytest.fixture(scope="function")
def db(
request: SubRequest,
dialect: str,
Expand All @@ -162,12 +170,14 @@ def db(


# def _db_with_lock(engine: AsyncEngine) -> DbSessionFactory:
# lock, db = asyncio.Lock(), _db(engine)
# lock = threading.Lock()
# db = _db(engine)

# @contextlib.asynccontextmanager
# async def factory() -> AsyncIterator[AsyncSession]:
# async with lock, db() as session:
# yield session
# with lock:
# async with db() as session:
# yield session

# return DbSessionFactory(db=factory, dialect=engine.dialect.name)

Expand Down Expand Up @@ -227,7 +237,6 @@ async def loop() -> AbstractEventLoop:
@pytest.fixture
def httpx_clients(
app: ASGIApp,
loop: AbstractEventLoop,
) -> Tuple[httpx.Client, httpx.AsyncClient]:
class Transport(httpx.BaseTransport, httpx.AsyncBaseTransport):
def __init__(self, transport: httpx.ASGITransport) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def experiment_task(_) -> Dict[str, str]:
evaluators={f"{i:02}": e for i, e in enumerate(evaluators)},
print_summary=False,
)
time.sleep(1) # Wait for the entire experiment to be run
experiment_id = from_global_id_with_expected_type(
GlobalID.from_id(experiment.id), "Experiment"
)
Expand Down Expand Up @@ -262,6 +261,7 @@ async def test_run_evaluation(
)
with patch("phoenix.experiments.functions._phoenix_clients", return_value=httpx_clients):
evaluate_experiment(experiment, evaluators=[lambda _: _])
time.sleep(1) # Wait for the evaluations to be inserted
async with db() as session:
evaluations = list(await session.scalars(select(models.ExperimentRunAnnotation)))
assert len(evaluations) == 1
Expand Down Expand Up @@ -378,6 +378,6 @@ def can_i_evaluate_with_everything_in_any_order(

async def test_get_experiment_client_method(px_client, simple_dataset_with_one_experiment_run):
experiment_gid = GlobalID("Experiment", "0")
experiment = px_client.get_experiment_runs(experiment_id=experiment_gid)
experiment = px_client.get_experiment(experiment_id=experiment_gid)
assert experiment
assert isinstance(experiment, Experiment)
2 changes: 1 addition & 1 deletion tests/server/api/routers/v1/test_spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_span_round_tripping_with_docs(
orig_count = len(orig_docs)
assert orig_count
px_client.log_traces(TraceDataset(df))
time.sleep(0.1) # Wait for the spans to be inserted
time.sleep(1) # Wait for the spans to be inserted
docs = cast(pd.DataFrame, px_client.query_spans(doc_query))
new_count = len(docs)
assert new_count
Expand Down

0 comments on commit b3cfce7

Please sign in to comment.