diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py index efdac129..afd2f6c3 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py @@ -14,9 +14,9 @@ from jupyter_server_ydoc.stores import SQLiteYStore from jupyter_ydoc import YNotebook, YUnicode from pycrdt_websocket import WebsocketProvider -from websockets import connect +from httpx_ws import aconnect_ws -from .test_utils import FakeContentsManager, FakeEventLogger, FakeFileIDManager +from .test_utils import FakeContentsManager, FakeEventLogger, FakeFileIDManager, Websocket @pytest.fixture @@ -126,8 +126,8 @@ def _inner(format: str, type: str, path: str) -> Any: @pytest.fixture def rtc_connect_awareness_client(jp_http_port, jp_base_url): async def _inner(room_id: str) -> Any: - return connect( - f"ws://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_id}" + return aconnect_ws( + f"http://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_id}" ) return _inner @@ -138,9 +138,10 @@ def rtc_connect_doc_client(jp_http_port, jp_base_url, rtc_fetch_session): async def _inner(format: str, type: str, path: str) -> Any: resp = await rtc_fetch_session(format, type, path) data = json.loads(resp.body.decode("utf-8")) - return connect( - f"ws://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{data['format']}:{data['type']}:{data['fileId']}?sessionId={data['sessionId']}" - ) + room_name = f"{data['format']}:{data['type']}:{data['fileId']}" + return aconnect_ws( + f"http://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_name}?sessionId={data['sessionId']}" + ), room_name return _inner @@ -162,9 +163,8 @@ def _on_document_change(target: str, e: Any) -> None: doc.observe(_on_document_change) - async with await rtc_connect_doc_client(format, type, path) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client(format, type, path) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py index f1140322..fc0263bc 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py @@ -7,6 +7,7 @@ from typing import Any from jupyter_server import _tz as tz +from anyio import Lock class FakeFileIDManager: @@ -55,3 +56,32 @@ def save_content(self, model: dict[str, Any], path: str) -> dict: class FakeEventLogger: def emit(self, schema_id: str, data: dict) -> None: print(data) + + +class Websocket: + def __init__(self, websocket, path: str): + self._websocket = websocket + self._path = path + self._send_lock = Lock() + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except Exception: + raise StopAsyncIteration() + return message + + async def send(self, message: bytes): + async with self._send_lock: + await self._websocket.send_bytes(message) + + async def recv(self) -> bytes: + b = await self._websocket.receive_bytes() + return bytes(b) diff --git a/projects/jupyter-server-ydoc/pyproject.toml b/projects/jupyter-server-ydoc/pyproject.toml index 7007980e..dce09301 100644 --- a/projects/jupyter-server-ydoc/pyproject.toml +++ b/projects/jupyter-server-ydoc/pyproject.toml @@ -45,7 +45,8 @@ test = [ "jupyter_server_fileid[test]", "pytest>=7.0", "pytest-cov", - "websockets", + "anyio", + "httpx-ws >=0.5.2", "importlib_metadata >=4.8.3; python_version<'3.10'", ] diff --git a/tests/test_documents.py b/tests/test_documents.py index 6573ad29..4b1428dc 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -12,6 +12,7 @@ import pytest from anyio import create_task_group, sleep from pycrdt_websocket import WebsocketProvider +from jupyter_server_ydoc.test_utils import Websocket jupyter_ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")} @@ -32,12 +33,12 @@ async def test_dirty( await rtc_create_file(file_path) jupyter_ydoc = jupyter_ydocs[file_type]() - async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws: - async with WebsocketProvider(jupyter_ydoc.ydoc, ws): - for _ in range(2): - jupyter_ydoc.dirty = True - await sleep(rtc_document_save_delay * 1.5) - assert not jupyter_ydoc.dirty + websocket, room_name = await rtc_connect_doc_client(file_format, file_type, file_path) + async with websocket as ws, WebsocketProvider(jupyter_ydoc.ydoc, Websocket(ws, room_name)): + for _ in range(2): + jupyter_ydoc.dirty = True + await sleep(rtc_document_save_delay * 1.5) + assert not jupyter_ydoc.dirty async def cleanup(jp_serverapp): @@ -59,7 +60,8 @@ async def test_room_concurrent_initialization( await rtc_create_file(file_path) async def connect(file_format, file_type, file_path): - async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws: + websocket, room_name = await rtc_connect_doc_client(file_format, file_type, file_path) + async with websocket as ws: pass t0 = time() @@ -84,7 +86,8 @@ async def test_room_sequential_opening( async def connect(file_format, file_type, file_path): t0 = time() - async with await rtc_connect_doc_client(file_format, file_type, file_path) as ws: + websocket, room_name = await rtc_connect_doc_client(file_format, file_type, file_path) + async with websocket as ws: pass t1 = time() return t1 - t0 diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 96aaded5..c8708c05 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -8,6 +8,7 @@ from typing import Any from jupyter_events.logger import EventLogger +from jupyter_server_ydoc.test_utils import Websocket from jupyter_ydoc import YUnicode from pycrdt_websocket import WebsocketProvider @@ -77,8 +78,9 @@ def _on_document_change(target: str, e: Any) -> None: doc = YUnicode() doc.observe(_on_document_change) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider( + doc.ydoc, Websocket(ws, room_name) ): await event.wait() await sleep(0.1) @@ -114,8 +116,9 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: listener=my_listener, ) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider( + doc.ydoc, Websocket(ws, room_name) ): await event.wait() await sleep(0.1) @@ -147,8 +150,9 @@ def _on_document_change(target: str, e: Any) -> None: doc = YUnicode() doc.observe(_on_document_change) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider( + doc.ydoc, Websocket(ws, room_name) ): await event.wait() await sleep(0.1) @@ -173,8 +177,9 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: path2, _ = await rtc_create_file("test2.txt", "test2") try: - async with await rtc_connect_doc_client("text2", "file2", path2) as ws, WebsocketProvider( - doc.ydoc, ws + websocket, room_name = await rtc_connect_doc_client("text2", "file2", path2) + async with websocket as ws, WebsocketProvider( + doc.ydoc, Websocket(ws, room_name) ): await event.wait() await sleep(0.1) @@ -182,8 +187,9 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: pass try: - async with await rtc_connect_doc_client("text2", "file2", path2) as ws, WebsocketProvider( - doc.ydoc, ws + websocket, room_name = await rtc_connect_doc_client("text2", "file2", path2) + async with websocket as ws, WebsocketProvider( + doc.ydoc, Websocket(ws, room_name) ): await event.wait() await sleep(0.1)