Skip to content

Commit

Permalink
Merge pull request #35 from davidbrochart/fix-yroom
Browse files Browse the repository at this point in the history
Fix room async context manager
  • Loading branch information
Zsailer authored Apr 29, 2024
2 parents 3764238 + 4e61046 commit a27dca5
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
mypy pycrdt_websocket tests
- name: Run tests
run: |
pytest -v --color=yes
pytest -v --color=yes --timeout=60
check_release:
runs-on: ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ async def start(
if from_context_manager:
task_status.started()
self.started.set()
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ test = [
"mypy !=1.10.0", # see https://github.com/python/mypy/issues/17166
"pre-commit",
"pytest",
"pytest-timeout",
"httpx-ws >=0.5.2",
"hypercorn >=0.16.0",
"trio >=0.25.0",
Expand Down
59 changes: 47 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,36 @@
from hypercorn import Config
from pycrdt import Doc
from sniffio import current_async_library
from utils import StartStopContextManager, Websocket, ensure_server_running
from utils import StartStopContextManager, Websocket, connected_websockets, ensure_server_running

from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer
from pycrdt_websocket import ASGIServer, WebsocketProvider, WebsocketServer, YRoom


@pytest.fixture(params=("websocket_server_context_manager", "websocket_server_start_stop"))
def websocket_server_api(request):
return request.param


@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop"))
def websocket_provider_api(request):
return request.param


@pytest.fixture(params=("yroom_context_manager", "yroom_start_stop"))
def yroom_api(request):
return request.param


@pytest.fixture(params=("real_websocket",))
def websocket_provider_connect(request):
return request.param


@pytest.fixture(params=("ystore_context_manager", "ystore_start_stop"))
def ystore_api(request):
return request.param


@pytest.fixture
async def yws_server(request, unused_tcp_port, websocket_server_api):
try:
Expand Down Expand Up @@ -50,31 +70,32 @@ async def yws_server(request, unused_tcp_port, websocket_server_api):
pass


@pytest.fixture(params=("websocket_provider_context_manager", "websocket_provider_start_stop"))
def websocket_provider_api(request):
return request.param


@pytest.fixture
def yws_provider_factory(room_name, websocket_provider_api):
def yws_provider_factory(room_name, websocket_provider_api, websocket_provider_connect):
@asynccontextmanager
async def factory():
ydoc = Doc()
async with aconnect_ws(f"http://localhost:{pytest.port}/{room_name}") as websocket:
if websocket_provider_connect == "real_websocket":
server_websocket = None
connect = aconnect_ws(f"http://localhost:{pytest.port}/{room_name}")
else:
server_websocket, connect = connected_websockets()
async with connect as websocket:
async with create_task_group() as tg:
websocket_provider = WebsocketProvider(ydoc, Websocket(websocket, room_name))
if websocket_provider_api == "websocket_provider_start_stop":
websocket_provider = StartStopContextManager(websocket_provider, tg)
async with websocket_provider as websocket_provider:
yield ydoc
yield ydoc, server_websocket

return factory


@pytest.fixture
async def yws_provider(yws_provider_factory):
async with yws_provider_factory() as ydoc:
yield ydoc
async with yws_provider_factory() as provider:
ydoc, server_websocket = provider
yield ydoc, server_websocket


@pytest.fixture
Expand All @@ -83,6 +104,20 @@ async def yws_providers(request, yws_provider_factory):
yield [yws_provider_factory() for idx in range(number)]


@pytest.fixture
async def yroom(request, yroom_api):
async with create_task_group() as tg:
try:
kwargs = request.param
except AttributeError:
kwargs = {}
room = YRoom(**kwargs)
if yroom_api == "yroom_start_stop":
room = StartStopContextManager(room, tg)
async with room as room:
yield room


@pytest.fixture
def yjs_client(request):
client_id = request.param
Expand Down
6 changes: 4 additions & 2 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
async def test_asgi(yws_server, yws_providers):
yws_provider1, yws_provider2 = yws_providers
# client 1
async with yws_provider1 as ydoc1:
async with yws_provider1 as yws_provider1:
ydoc1, _ = yws_provider1
ydoc1["map"] = ymap1 = Map()
ymap1["key"] = "value"
await sleep(0.1)

# client 2
async with yws_provider2 as ydoc2:
async with yws_provider2 as yws_provider2:
ydoc2, _ = yws_provider2
ymap2 = ydoc2.get("map", type=Map)
await sleep(0.1)
assert str(ymap2) == '{"key":"value"}'
4 changes: 2 additions & 2 deletions tests/test_pycrdt_yjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def watch(ydata, key: str | None = None, timeout: float = 1.0):

@pytest.mark.parametrize("yjs_client", [0], indirect=True)
async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client):
ydoc = yws_provider
ydoc, _ = yws_provider
ydoc["map"] = ymap = Map()
for v_in in range(10):
ymap["in"] = float(v_in)
Expand All @@ -49,7 +49,7 @@ async def test_pycrdt_yjs_0(yws_server, yws_provider, yjs_client):

@pytest.mark.parametrize("yjs_client", [1], indirect=True)
async def test_pycrdt_yjs_1(yws_server, yws_provider, yjs_client):
ydoc = yws_provider
ydoc, _ = yws_provider
ydoc["cells"] = ycells = Array()
ydoc["state"] = ystate = Map()
ycells_change = watch(ycells)
Expand Down
29 changes: 27 additions & 2 deletions tests/test_yroom.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
import pytest
from anyio import TASK_STATUS_IGNORED, sleep
from anyio import TASK_STATUS_IGNORED, create_task_group, sleep
from anyio.abc import TaskStatus
from pycrdt import Map
from utils import Websocket

from pycrdt_websocket import exception_logger
from pycrdt_websocket.yroom import YRoom

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_provider_connect", ["fake_websocket"], indirect=True)
@pytest.mark.parametrize("yws_providers", [2], indirect=True)
async def test_yroom(yroom, yws_providers, websocket_provider_connect, room_name):
async with create_task_group() as tg:
yws_provider1, yws_provider2 = yws_providers
# client 1
async with yws_provider1 as yws_provider1:
ydoc1, server_ws1 = yws_provider1
tg.start_soon(yroom.serve, Websocket(server_ws1, room_name))
ydoc1["map"] = ymap1 = Map()
ymap1["key"] = "value"
await sleep(0.1)

# client 2
async with yws_provider2 as yws_provider2:
ydoc2, server_ws2 = yws_provider2
tg.start_soon(yroom.serve, Websocket(server_ws2, room_name))
ymap2 = ydoc2.get("map", type=Map)
await sleep(0.1)

assert str(ymap2) == '{"key":"value"}'
tg.cancel_scope.cancel()


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_yroom_restart(yws_server, yws_provider):
Expand All @@ -19,7 +44,7 @@ async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
raise RuntimeError("foo")

yroom.ydoc = yws_provider
yroom.ydoc, _ = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
Expand Down
41 changes: 40 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from anyio import Lock, connect_tcp
from anyio import Lock, connect_tcp, create_memory_object_stream
from pycrdt import Array, Doc


Expand Down Expand Up @@ -60,6 +60,45 @@ async def recv(self) -> bytes:
return bytes(b)


class ClientWebsocket:
def __init__(self, server_websocket: "ServerWebsocket"):
self.server_websocket = server_websocket
self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536)

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
pass

async def send_bytes(self, message: bytes) -> None:
await self.server_websocket.send_stream.send(message)

async def receive_bytes(self) -> bytes:
return await self.receive_stream.receive()


class ServerWebsocket:
client_websocket: ClientWebsocket | None = None

def __init__(self):
self.send_stream, self.receive_stream = create_memory_object_stream[bytes](65536)

async def send_bytes(self, message: bytes) -> None:
assert self.client_websocket is not None
await self.client_websocket.send_stream.send(message)

async def receive_bytes(self) -> bytes:
return await self.receive_stream.receive()


def connected_websockets() -> tuple[ServerWebsocket, ClientWebsocket]:
server_websocket = ServerWebsocket()
client_websocket = ClientWebsocket(server_websocket)
server_websocket.client_websocket = client_websocket
return server_websocket, client_websocket


async def ensure_server_running(host: str, port: int) -> None:
while True:
try:
Expand Down

0 comments on commit a27dca5

Please sign in to comment.