diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py index efdac129..8bb8479e 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/pytest_plugin.py @@ -9,14 +9,19 @@ import nbformat import pytest +from httpx_ws import aconnect_ws from jupyter_server_ydoc.loaders import FileLoader from jupyter_server_ydoc.rooms import DocumentRoom from jupyter_server_ydoc.stores import SQLiteYStore from jupyter_ydoc import YNotebook, YUnicode from pycrdt_websocket import WebsocketProvider -from websockets import connect -from .test_utils import FakeContentsManager, FakeEventLogger, FakeFileIDManager +from .test_utils import ( + FakeContentsManager, + FakeEventLogger, + FakeFileIDManager, + Websocket, +) @pytest.fixture @@ -126,8 +131,8 @@ def _inner(format: str, type: str, path: str) -> Any: @pytest.fixture def rtc_connect_awareness_client(jp_http_port, jp_base_url): async def _inner(room_id: str) -> Any: - return connect( - f"ws://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_id}" + return aconnect_ws( + f"http://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_id}" ) return _inner @@ -138,8 +143,12 @@ def rtc_connect_doc_client(jp_http_port, jp_base_url, rtc_fetch_session): async def _inner(format: str, type: str, path: str) -> Any: resp = await rtc_fetch_session(format, type, path) data = json.loads(resp.body.decode("utf-8")) - return connect( - f"ws://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{data['format']}:{data['type']}:{data['fileId']}?sessionId={data['sessionId']}" + room_name = f"{data['format']}:{data['type']}:{data['fileId']}" + return ( + aconnect_ws( + f"http://127.0.0.1:{jp_http_port}{jp_base_url}api/collaboration/room/{room_name}?sessionId={data['sessionId']}" + ), + room_name, ) return _inner @@ -162,9 +171,8 @@ def _on_document_change(target: str, e: Any) -> None: doc.observe(_on_document_change) - async with await rtc_connect_doc_client(format, type, path) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client(format, type, path) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py index f1140322..24528004 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/test_utils.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import Any +from anyio import Lock from jupyter_server import _tz as tz @@ -55,3 +56,32 @@ def save_content(self, model: dict[str, Any], path: str) -> dict: class FakeEventLogger: def emit(self, schema_id: str, data: dict) -> None: print(data) + + +class Websocket: + def __init__(self, websocket: Any, path: str): + self._websocket = websocket + self._path = path + self._send_lock = Lock() + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + try: + message = await self.recv() + except Exception: + raise StopAsyncIteration() + return message + + async def send(self, message: bytes) -> None: + async with self._send_lock: + await self._websocket.send_bytes(message) + + async def recv(self) -> bytes: + b = await self._websocket.receive_bytes() + return bytes(b) diff --git a/projects/jupyter-server-ydoc/pyproject.toml b/projects/jupyter-server-ydoc/pyproject.toml index 7007980e..dce09301 100644 --- a/projects/jupyter-server-ydoc/pyproject.toml +++ b/projects/jupyter-server-ydoc/pyproject.toml @@ -45,7 +45,8 @@ test = [ "jupyter_server_fileid[test]", "pytest>=7.0", "pytest-cov", - "websockets", + "anyio", + "httpx-ws >=0.5.2", "importlib_metadata >=4.8.3; python_version<'3.10'", ] diff --git a/pyproject.toml b/pyproject.toml index 32e7cecd..d17177ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,8 @@ filterwarnings = [ "ignore:.*datetime.utcfromtimestamp\\(\\) is deprecated.*:DeprecationWarning:", # From anyio https://github.com/agronholm/anyio/pull/715 'ignore:Unclosed None: doc = YUnicode() doc.observe(_on_document_change) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) @@ -114,9 +114,8 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: listener=my_listener, ) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) @@ -147,9 +146,8 @@ def _on_document_change(target: str, e: Any) -> None: doc = YUnicode() doc.observe(_on_document_change) - async with await rtc_connect_doc_client("text", "file", path) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client("text", "file", path) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) @@ -173,18 +171,16 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None: path2, _ = await rtc_create_file("test2.txt", "test2") try: - async with await rtc_connect_doc_client("text2", "file2", path2) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client("text2", "file2", path2) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) except Exception: pass try: - async with await rtc_connect_doc_client("text2", "file2", path2) as ws, WebsocketProvider( - doc.ydoc, ws - ): + websocket, room_name = await rtc_connect_doc_client("text2", "file2", path2) + async with websocket as ws, WebsocketProvider(doc.ydoc, Websocket(ws, room_name)): await event.wait() await sleep(0.1) except Exception: