diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py index e06a8204..a5c20a04 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/handlers.py @@ -82,6 +82,21 @@ async def prepare(self): if self._websocket_server.room_exists(self._room_id): self.room: YRoom = await self._websocket_server.get_room(self._room_id) else: + # Logging exceptions, instead of raising them here to ensure + # that the y-rooms stay alive even after an exception is seen. + def exception_logger(exception: Exception, log: Logger) -> bool: + """A function that catches any exceptions raised in the websocket + server and logs them. + + The protects the y-room's task group from cancelling + anytime an exception is raised. + """ + log.error( + f"Document Room Exception, (room_id={self._room_id or 'unknown'}: ", + exc_info=exception, + ) + return True + if self._room_id.count(":") >= 2: # DocumentRoom file_format, file_type, file_id = decode_file_path(self._room_id) @@ -103,33 +118,18 @@ async def prepare(self): self.event_logger, ystore, self.log, - self._document_save_delay, + exception_handler=exception_logger, + save_delay=self._document_save_delay, ) - def exception_logger(exception: Exception, log: Logger) -> bool: - """A function that catches any exceptions raised in the websocket - server and logs them. - - The protects the y-room's task group from cancelling - anytime an exception is raised. - """ - room_id = "unknown" - if self.room.room_id: - room_id = self.room.room_id - log.error( - f"Document Room Exception, (room_id={room_id}: ", - exc_info=exception, - ) - return True - - # Logging exceptions, instead of raising them here to ensure - # that the y-rooms stay alive even after an exception is seen. - self.room.exception_handler = exception_logger - else: # TransientRoom # it is a transient document (e.g. awareness) - self.room = TransientRoom(self._room_id, self.log) + self.room = TransientRoom( + self._room_id, + exception_handler=exception_logger, + log=self.log, + ) await self._websocket_server.start_room(self.room) self._websocket_server.add_room(self._room_id, self.room) diff --git a/projects/jupyter-server-ydoc/jupyter_server_ydoc/rooms.py b/projects/jupyter-server-ydoc/jupyter_server_ydoc/rooms.py index 7b5e5ca1..00787dbe 100644 --- a/projects/jupyter-server-ydoc/jupyter_server_ydoc/rooms.py +++ b/projects/jupyter-server-ydoc/jupyter_server_ydoc/rooms.py @@ -5,7 +5,7 @@ import asyncio from logging import Logger -from typing import Any +from typing import Any, Callable from jupyter_events import EventLogger from jupyter_ydoc import ydocs as YDOCS @@ -31,8 +31,9 @@ def __init__( ystore: BaseYStore | None, log: Logger | None, save_delay: float | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, ): - super().__init__(ready=False, ystore=ystore, log=log) + super().__init__(ready=False, ystore=ystore, exception_handler=exception_handler, log=log) self._room_id: str = room_id self._file_format: str = file_format @@ -285,8 +286,13 @@ async def _maybe_save_document(self, saving_document: asyncio.Task | None) -> No class TransientRoom(YRoom): """A Y room for sharing state (e.g. awareness).""" - def __init__(self, room_id: str, log: Logger | None): - super().__init__(log=log) + def __init__( + self, + room_id: str, + log: Logger | None = None, + exception_handler: Callable[[Exception, Logger], bool] | None = None, + ): + super().__init__(log=log, exception_handler=exception_handler) self._room_id = room_id