Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix room async context manager #35

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
mypy pycrdt_websocket tests
- name: Run tests
run: |
pytest -v --color=yes
pytest -v --color=yes --timeout=60
check_release:
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ async def start(
if from_context_manager:
task_status.started()
self.started.set()
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ test = [
"mypy !=1.10.0", # see https://github.com/python/mypy/issues/17166
"pre-commit",
"pytest",
"pytest-timeout",
"httpx-ws >=0.5.2",
"hypercorn >=0.16.0",
"trio >=0.25.0",
Expand Down
59 changes: 47 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,36 @@
from hypercorn import Config
from pycrdt import Doc
from sniffio import current_async_library
from utils import StartStopContextManager, Websocket, ensure_server_running
from utils import StartStopContextManager, Websocket, connected_websockets, ensure_server_running

from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer
from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer, YRoom


@pytest.fixture(params=("websocket_server_context_manager", "websocket_server_start_stop"))
def websocket_server_api(request):
return request.param


@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop"))
def websocket_provider_api(request):
return request.param


@pytest.fixture(params=("yroom_context_manager", "yroom_start_stop"))
def yroom_api(request):
return request.param


@pytest.fixture(params=("real_websocket",))
def websocket_provider_connect(request):
return request.param


@pytest.fixture(params=("ystore_context_manager", "ystore_start_stop"))
def ystore_api(request):
return request.param


@pytest.fixture
async def yws_server(request, unused_tcp_port, websocket_server_api):
try:
Expand Down Expand Up @@ -50,31 +70,32 @@ async def yws_server(request, unused_tcp_port, websocket_server_api):
pass


@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop"))
def websocket_provider_api(request):
return request.param


@pytest.fixture
def yws_provider_factory(room_name, websocket_provider_api):
def yws_provider_factory(room_name, websocket_provider_api, websocket_provider_connect):
@asynccontextmanager
async def factory():
ydoc = Doc()
async with aconnect_ws(f"http://localhost:{pytest.port}/{room_name}") as websocket:
if websocket_provider_connect == "real_websocket":
server_websocket = None
connect = aconnect_ws(f"http://localhost:{pytest.port}/{room_name}")
else:
server_websocket, connect = connected_websockets()
async with connect as websocket:
async with create_task_group() as tg:
websocket_provider = WebsocketProvider(ydoc, Websocket(websocket, room_name))
if websocket_provider_api == "websocket_provider_start_stop":
websocket_provider = StartStopContextManager(websocket_provider, tg)
async with websocket_provider as websocket_provider:
yield ydoc
yield ydoc, server_websocket

return factory


@pytest.fixture
async def yws_provider(yws_provider_factory):
async with yws_provider_factory() as ydoc:
yield ydoc
async with yws_provider_factory() as provider:
ydoc, server_websocket = provider
yield ydoc, server_websocket


@pytest.fixture
Expand All @@ -83,6 +104,20 @@ async def yws_providers(request, yws_provider_factory):
yield [yws_provider_factory() for idx in range(number)]


@pytest.fixture
async def yroom(request, yroom_api):
async with create_task_group() as tg:
try:
kwargs = request.param
except AttributeError:
kwargs = {}
room = YRoom(**kwargs)
if yroom_api == "yroom_start_stop":
room = StartStopContextManager(room, tg)
async with room as room:
yield room


@pytest.fixture
def yjs_client(request):
client_id = request.param
Expand Down
6 changes: 4 additions & 2 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
async def test_asgi(yws_server, yws_providers):
yws_provider1, yws_provider2 = yws_providers
# client 1
async with yws_provider1 as ydoc1:
async with yws_provider1 as yws_provider1:
ydoc1, _ = yws_provider1
ydoc1["map"] = ymap1 = Map()
ymap1["key"] = "value"
await sleep(0.1)

# client 2
async with yws_provider2 as ydoc2:
async with yws_provider2 as yws_provider2:
ydoc2, _ = yws_provider2
ymap2 = ydoc2.get("map", type=Map)
await sleep(0.1)
assert str(ymap2) == '{"key":"value"}'
4 changes: 2 additions & 2 deletions tests/test_pycrdt_yjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def watch(ydata, key: str | None = None, timeout: float = 1.0):

@pytest.mark.parametrize("yjs_client", [0], indirect=True)
async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client):
ydoc = yws_provider
ydoc, _ = yws_provider
ydoc["map"] = ymap = Map()
for v_in in range(10):
ymap["in"] = float(v_in)
Expand All @@ -49,7 +49,7 @@ async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client):

@pytest.mark.parametrize("yjs_client", [1], indirect=True)
async def test_pycrdt_yjs_1(yws_server, yws_provider, yjs_client):
ydoc = yws_provider
ydoc, _ = yws_provider
ydoc["cells"] = ycells = Array()
ydoc["state"] = ystate = Map()
ycells_change = watch(ycells)
Expand Down
29 changes: 27 additions & 2 deletions tests/test_yroom.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
import pytest
from anyio import TASK_STATUS_IGNORED, sleep
from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from anyio.abc import TaskStatus
from pycrdt import Map
from utils import Websocket

from pycrdt_websocket import exception_logger
from pycrdt_websocket.yroom import YRoom

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_provider_connect", ["fake_websocket"], indirect=True)
@pytest.mark.parametrize("yws_providers", [2], indirect=True)
async def test_yroom(yroom, yws_providers, websocket_provider_connect, room_name):
async with create_task_group() as tg:
yws_provider1, yws_provider2 = yws_providers
# client 1
async with yws_provider1 as yws_provider1:
ydoc1, server_ws1 = yws_provider1
tg.start_soon(yroom.serve, Websocket(server_ws1, room_name))
ydoc1["map"] = ymap1 = Map()
ymap1["key"] = "value"
await sleep(0.1)

# client 2
async with yws_provider2 as yws_provider2:
ydoc2, server_ws2 = yws_provider2
tg.start_soon(yroom.serve, Websocket(server_ws2, room_name))
ymap2 = ydoc2.get("map", type=Map)
await sleep(0.1)

assert str(ymap2) == '{"key":"value"}'
tg.cancel_scope.cancel()


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_yroom_restart(yws_server, yws_provider):
Expand All @@ -19,7 +44,7 @@ async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
raise RuntimeError("foo")

yroom.ydoc = yws_provider
yroom.ydoc, _ = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
Expand Down
41 changes: 40 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from anyio import Lock, connect_tcp
from anyio import Lock, connect_tcp, create_memory_object_stream
from pycrdt import Array, Doc


Expand Down Expand Up @@ -60,6 +60,45 @@ async def recv(self) -> bytes:
return bytes(b)


class ClientWebsocket:
def __init__(self, server_websocket: "ServerWebsocket"):
self.server_websocket = server_websocket
self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
pass

async def send_bytes(self, message: bytes) -> None:
await self.server_websocket.send_stream.send(message)

async def receive_bytes(self) -> bytes:
return await self.receive_stream.receive()


class ServerWebsocket:
client_websocket: ClientWebsocket | None = None

def __init__(self):
self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536)

async def send_bytes(self, message: bytes) -> None:
assert self.client_websocket is not None
await self.client_websocket.send_stream.send(message)

async def receive_bytes(self) -> bytes:
return await self.receive_stream.receive()


def connected_websockets() -> tuple[ServerWebsocket, ClientWebsocket]:
server_websocket = ServerWebsocket()
client_websocket = ClientWebsocket(server_websocket)
server_websocket.client_websocket = client_websocket
return server_websocket, client_websocket


async def ensure_server_running(host: str, port: int) -> None:
while True:
try:
Expand Down
Loading