Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding exception handling for room start tasks #33

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 56 additions & 30 deletions pycrdt_websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ class YRoom:
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
_started: Event | None = None
_stopped: Event
__start_lock: Lock | None = None
_subscription: Subscription | None = None

def __init__(
self, ready: bool = True, ystore: BaseYStore | None = None, log: Logger | None = None
self,
ready: bool = True,
ystore: BaseYStore | None = None,
exception_handler: Callable[[Exception, Logger], bool] | None = None,
log: Logger | None = None,
):
"""Initialize the object.

Expand All @@ -63,19 +68,20 @@ def __init__(
Arguments:
ready: Whether the internal YDoc is ready to be synchronized right away.
ystore: An optional store in which to persist document updates.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
log: An optional logger.
"""
self.ydoc = Doc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self.ready_event = Event()
self.ready = ready
self.ystore = ystore
self.log = log or getLogger(__name__)
self.clients = []
self._on_message = None
self.exception_handler = exception_handler
self._stopped = Event()

@property
def _start_lock(self) -> Lock:
Expand Down Expand Up @@ -138,30 +144,42 @@ async def _broadcast_updates(self):
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
try:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
except Exception as exception:
self._handle_exception(exception)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)
try:
self._task_group.start_soon(self.ystore.write, update)
self.log.debug("Writing Y update to YStore")
except Exception as exception:
self._handle_exception(exception)
Comment on lines +154 to +158
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that a follow-up to this PR would be to not handle exceptions here, but have a YStore have an optional exception handler, and do the handling there.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, @davidbrochart. I've created the following issue to track: #36


async def __aenter__(self) -> YRoom:
async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._exit_stack = exit_stack.pop_all()
await tg.start(partial(self.start, from_context_manager=True))
await self._task_group.start(partial(self.start, from_context_manager=True))

return self

async def __aexit__(self, exc_type, exc_value, exc_tb):
await self.stop()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

def _handle_exception(self, exception: Exception) -> None:
exception_handled = False
if self.exception_handler is not None:
exception_handled = self.exception_handler(exception, self.log)
if not exception_handled:
raise exception

async def start(
self,
*,
Expand All @@ -177,27 +195,36 @@ async def start(
task_status.started()
self.started.set()
assert self._task_group is not None
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
self._task_group.start_soon(self._broadcast_updates)
return

async with self._start_lock:
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with create_task_group() as self._task_group:
task_status.started()
self.started.set()
self._task_group.start_soon(self._broadcast_updates)
self._task_group.start_soon(self._watch_ready)
while True:
try:
async with create_task_group() as self._task_group:
if not self.started.is_set():
task_status.started()
self.started.set()
self._update_send_stream, self._update_receive_stream = (
create_memory_object_stream(max_buffer_size=65536)
)
self._task_group.start_soon(self._stopped.wait)
self._task_group.start_soon(self._watch_ready)
self._task_group.start_soon(self._broadcast_updates)
return
except Exception as exception:
self._handle_exception(exception)

async def stop(self) -> None:
"""Stop the room."""
if self._task_group is None:
raise RuntimeError("YRoom not running")

if self._task_group is None:
return

self._stopped.set()
self._task_group.cancel_scope.cancel()
self._task_group = None
if self._subscription is not None:
Expand All @@ -209,10 +236,10 @@ async def serve(self, websocket: Websocket):
Arguments:
websocket: The WebSocket through which to serve the client.
"""
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
try:
try:
async with create_task_group() as tg:
self.clients.append(websocket)
await sync(self.ydoc, websocket, self.log)
async for message in websocket:
# filter messages (e.g. awareness)
skip = False
Expand Down Expand Up @@ -245,8 +272,7 @@ async def serve(self, websocket: Websocket):
client.path,
)
tg.start_soon(client.send, message)
except Exception as e:
self.log.debug("Error serving endpoint: %s", websocket.path, exc_info=e)

# remove this client
self.clients = [c for c in self.clients if c != websocket]
# remove this client
self.clients = [c for c in self.clients if c != websocket]
except Exception as exception:
self._handle_exception(exception)
33 changes: 33 additions & 0 deletions tests/test_yroom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from anyio import TASK_STATUS_IGNORED, sleep
from anyio.abc import TaskStatus
from pycrdt import Map

from pycrdt_websocket import exception_logger
from pycrdt_websocket.yroom import YRoom

pytestmark = pytest.mark.anyio


@pytest.mark.parametrize("websocket_server_api", ["websocket_server_start_stop"], indirect=True)
@pytest.mark.parametrize("yws_server", [{"exception_handler": exception_logger}], indirect=True)
async def test_yroom_restart(yws_server, yws_provider):
port, server = yws_server
yroom = YRoom(exception_handler=exception_logger)

async def raise_error(task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
raise RuntimeError("foo")

yroom.ydoc = yws_provider
await server.start_room(yroom)
yroom.ydoc["map"] = ymap1 = Map()
ymap1["key"] = "value"
task_group_1 = yroom._task_group
await yroom._task_group.start(raise_error)
ymap1["key2"] = "value2"
Zsailer marked this conversation as resolved.
Show resolved Hide resolved
await sleep(0.1)
assert yroom._task_group is not task_group_1
assert yroom._task_group is not None
assert not yroom._task_group.cancel_scope.cancel_called
await yroom.stop()
Loading