diff --git a/jupyter_collaboration/loaders.py b/jupyter_collaboration/loaders.py index 140e7c82..62cc0a0a 100644 --- a/jupyter_collaboration/loaders.py +++ b/jupyter_collaboration/loaders.py @@ -38,9 +38,7 @@ def __init__( self._contents_manager = contents_manager self._log = log or getLogger(__name__) - self._subscriptions: dict[ - str, Callable[[str, dict[str, Any]], Coroutine[Any, Any, None]] - ] = {} + self._subscriptions: dict[str, Callable[[], Coroutine[Any, Any, None]]] = {} self._watcher = asyncio.create_task(self._watch_file()) if self._poll_interval else None self.last_modified = None @@ -78,11 +76,9 @@ async def clean(self) -> None: self._watcher.cancel() await self._watcher - def observe( - self, id: str, callback: Callable[[str, dict[str, Any]], Coroutine[Any, Any, None]] - ) -> None: + def observe(self, id: str, callback: Callable[[], Coroutine[Any, Any, None]]) -> None: """ - Subscribe to the file to get notified on file changes. + Subscribe to the file to get notified about out-of-band file changes. Parameters: id (str): Room ID @@ -99,7 +95,7 @@ def unobserve(self, id: str) -> None: """ del self._subscriptions[id] - async def load_content(self, format: str, file_type: str, content: bool) -> dict[str, Any]: + async def load_content(self, format: str, file_type: str) -> dict[str, Any]: """ Load the content of the file. @@ -112,31 +108,11 @@ async def load_content(self, format: str, file_type: str, content: bool) -> dict model (dict): A dictionary with the metadata and content of the file. """ async with self._lock: - return await ensure_async( - self._contents_manager.get( - self.path, format=format, type=file_type, content=content - ) + model = await ensure_async( + self._contents_manager.get(self.path, format=format, type=file_type, content=True) ) - - async def save_content(self, model: dict[str, Any]) -> dict[str, Any]: - """ - Save the content of the file. - - Parameters: - model (dict): A dictionary with format, type, last_modified, and content of the file. - - Returns: - model (dict): A dictionary with the metadata and content of the file. - """ - async with self._lock: - path = self.path - if model["type"] not in {"directory", "file", "notebook"}: - # fall back to file if unknown type, the content manager only knows - # how to handle these types - model["type"] = "file" - - self._log.info("Saving file: %s", path) - return await ensure_async(self._contents_manager.save(model, path)) + self.last_modified = model["last_modified"] + return model async def maybe_save_content(self, model: dict[str, Any]) -> None: """ @@ -168,16 +144,24 @@ async def maybe_save_content(self, model: dict[str, Any]) -> None: self._log.info("Saving file: %s", path) # saving is shielded so that it cannot be cancelled # otherwise it could corrupt the file - task = asyncio.create_task(self._save_content(model)) - await asyncio.shield(task) - + done_saving = asyncio.Event() + task = asyncio.create_task(self._save_content(model, done_saving)) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + pass + await done_saving.wait() else: # file changed on disk, raise an error + self.last_modified = m["last_modified"] raise OutOfBandChanges - async def _save_content(self, model: dict[str, Any]) -> None: - m = await ensure_async(self._contents_manager.save(model, self.path)) - self.last_modified = m["last_modified"] + async def _save_content(self, model: dict[str, Any], done_saving: asyncio.Event) -> None: + try: + m = await ensure_async(self._contents_manager.save(model, self.path)) + self.last_modified = m["last_modified"] + finally: + done_saving.set() async def _watch_file(self) -> None: """ @@ -192,24 +176,31 @@ async def _watch_file(self) -> None: try: await asyncio.sleep(self._poll_interval) try: - await self.notify() + await self.maybe_notify() except Exception as e: self._log.error(f"Error watching file: {self.path}\n{e!r}", exc_info=e) except asyncio.CancelledError: break - async def notify(self) -> None: + async def maybe_notify(self) -> None: """ - Notifies subscribed rooms about changes on the content of the file. + Notifies subscribed rooms about out-of-band file changes. """ + do_notify = False async with self._lock: - path = self.path # Get model metadata; format and type are not need - model = await ensure_async(self._contents_manager.get(path, content=False)) + model = await ensure_async(self._contents_manager.get(self.path, content=False)) + + if self.last_modified is not None and self.last_modified < model["last_modified"]: + do_notify = True + + self.last_modified = model["last_modified"] - # Notify that the content changed on disk - for callback in self._subscriptions.values(): - await callback("metadata", model) + if do_notify: + # Notify out-of-band change + # callbacks will load the file content, thus release the lock before calling them + for callback in self._subscriptions.values(): + await callback() class FileLoaderMapping: diff --git a/jupyter_collaboration/rooms.py b/jupyter_collaboration/rooms.py index 8219ddf0..be77356e 100644 --- a/jupyter_collaboration/rooms.py +++ b/jupyter_collaboration/rooms.py @@ -51,7 +51,7 @@ def __init__( # Listen for document changes self._document.observe(self._on_document_change) - self._file.observe(self.room_id, self._on_content_change) + self._file.observe(self.room_id, self._on_outofband_change) @property def room_id(self) -> str: @@ -95,7 +95,7 @@ async def initialize(self) -> None: self.log.info("Initializing room %s", self._room_id) - model = await self._file.load_content(self._file_format, self._file_type, True) + model = await self._file.load_content(self._file_format, self._file_type) async with self._update_lock: # try to apply Y updates from the YStore for this document @@ -144,7 +144,6 @@ async def initialize(self) -> None: if self.ystore: await self.ystore.encode_state_as_update(self.ydoc) - self._file.last_modified = model["last_modified"] self._document.dirty = False self.ready = True self._emit(LogLevel.INFO, "initialize", "Room initialized") @@ -179,32 +178,24 @@ async def _broadcast_updates(self): except asyncio.CancelledError: pass - async def _on_content_change(self, event: str, args: dict[str, Any]) -> None: + async def _on_outofband_change(self) -> None: """ - Called when the file changes. - - Parameters: - event (str): Type of change. - args (dict): A dictionary with format, type, last_modified. + Called when the file got out-of-band changes. """ - if event == "metadata" and ( - self._file.last_modified is None or self._file.last_modified < args["last_modified"] - ): - self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id) - self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.") + self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id) + self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.") - try: - model = await self._file.load_content(self._file_format, self._file_type, True) - except Exception as e: - msg = f"Error loading content from file: {self._file.path}\n{e!r}" - self.log.error(msg, exc_info=e) - self._emit(LogLevel.ERROR, None, msg) - return None + try: + model = await self._file.load_content(self._file_format, self._file_type) + except Exception as e: + msg = f"Error loading content from file: {self._file.path}\n{e!r}" + self.log.error(msg, exc_info=e) + self._emit(LogLevel.ERROR, None, msg) + return - async with self._update_lock: - self._document.source = model["content"] - self._file.last_modified = model["last_modified"] - self._document.dirty = False + async with self._update_lock: + self._document.source = model["content"] + self._document.dirty = False def _on_document_change(self, target: str, event: Any) -> None: """ @@ -224,14 +215,11 @@ def _on_document_change(self, target: str, event: Any) -> None: if self._update_lock.locked(): return - if self._saving_document is not None and not self._saving_document.done(): - # the document is being saved, cancel that - self._saving_document.cancel() - self._saving_document = None - - self._saving_document = asyncio.create_task(self._maybe_save_document()) + self._saving_document = asyncio.create_task( + self._maybe_save_document(self._saving_document) + ) - async def _maybe_save_document(self) -> None: + async def _maybe_save_document(self, saving_document: asyncio.Task | None) -> None: """ Saves the content of the document to disk. @@ -243,6 +231,11 @@ async def _maybe_save_document(self) -> None: if self._save_delay is None: return + if saving_document is not None and not saving_document.done(): + # the document is being saved, cancel that + saving_document.cancel() + await saving_document + # save after X seconds of inactivity await asyncio.sleep(self._save_delay) @@ -263,7 +256,7 @@ async def _maybe_save_document(self) -> None: except OutOfBandChanges: self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id) try: - model = await self._file.load_content(self._file_format, self._file_type, True) + model = await self._file.load_content(self._file_format, self._file_type) except Exception as e: msg = f"Error loading content from file: {self._file.path}\n{e!r}" self.log.error(msg, exc_info=e) @@ -272,7 +265,6 @@ async def _maybe_save_document(self) -> None: async with self._update_lock: self._document.source = model["content"] - self._file.last_modified = model["last_modified"] self._document.dirty = False self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes while saving.") diff --git a/tests/test_loaders.py b/tests/test_loaders.py index 483a2c80..12e543ad 100644 --- a/tests/test_loaders.py +++ b/tests/test_loaders.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio -from datetime import datetime +from datetime import datetime, timedelta, timezone from jupyter_collaboration.loaders import FileLoader, FileLoaderMapping @@ -17,23 +17,24 @@ async def test_FileLoader_with_watcher(): paths = {} paths[id] = path - cm = FakeContentsManager({"last_modified": datetime.now()}) + cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)}) loader = FileLoader( id, FakeFileIDManager(paths), cm, poll_interval=0.1, ) + await loader.load_content("text", "file") triggered = False - async def trigger(*args): + async def trigger(): nonlocal triggered triggered = True loader.observe("test", trigger) - cm.model["last_modified"] = datetime.now() + cm.model["last_modified"] = datetime.now(timezone.utc) + timedelta(seconds=1) await asyncio.sleep(0.15) @@ -49,24 +50,25 @@ async def test_FileLoader_without_watcher(): paths = {} paths[id] = path - cm = FakeContentsManager({"last_modified": datetime.now()}) + cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)}) loader = FileLoader( id, FakeFileIDManager(paths), cm, ) + await loader.load_content("text", "file") triggered = False - async def trigger(*args): + async def trigger(): nonlocal triggered triggered = True loader.observe("test", trigger) - cm.model["last_modified"] = datetime.now() + cm.model["last_modified"] = datetime.now(timezone.utc) + timedelta(seconds=1) - await loader.notify() + await loader.maybe_notify() try: assert triggered @@ -80,7 +82,7 @@ async def test_FileLoaderMapping_with_watcher(): paths = {} paths[id] = path - cm = FakeContentsManager({"last_modified": datetime.now()}) + cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)}) map = FileLoaderMapping( {"contents_manager": cm, "file_id_manager": FakeFileIDManager(paths)}, @@ -88,10 +90,11 @@ async def test_FileLoaderMapping_with_watcher(): ) loader = map[id] + await loader.load_content("text", "file") triggered = False - async def trigger(*args): + async def trigger(): nonlocal triggered triggered = True @@ -99,7 +102,7 @@ async def trigger(*args): # Clear map (and its loader) before updating => triggered should be False await map.clear() - cm.model["last_modified"] = datetime.now() + cm.model["last_modified"] = datetime.now(timezone.utc) await asyncio.sleep(0.15)