diff --git a/jupyter_collaboration/loaders.py b/jupyter_collaboration/loaders.py index a2622cca..1d0e07d0 100644 --- a/jupyter_collaboration/loaders.py +++ b/jupyter_collaboration/loaders.py @@ -39,9 +39,11 @@ def __init__( self._log = log or getLogger(__name__) self._subscriptions: dict[str, Callable[[], Coroutine[Any, Any, None]]] = {} + self._filepath_subscriptions: dict[str, Callable[[], Coroutine[Any, Any, None] | None]] = {} self._watcher = asyncio.create_task(self._watch_file()) if self._poll_interval else None self.last_modified = None + self._current_path = self.path @property def file_id(self) -> str: @@ -79,7 +81,12 @@ async def clean(self) -> None: except asyncio.CancelledError: self._log.info(f"file watcher for '{self.file_id}' is cancelled now") - def observe(self, id: str, callback: Callable[[], Coroutine[Any, Any, None]]) -> None: + def observe( + self, + id: str, + callback: Callable[[], Coroutine[Any, Any, None]], + filepath_callback: Callable[[], Coroutine[Any, Any, None] | None] | None = None, + ) -> None: """ Subscribe to the file to get notified about out-of-band file changes. @@ -88,6 +95,8 @@ def observe(self, id: str, callback: Callable[[], Coroutine[Any, Any, None]]) -> callback (Callable): Callback for notifying the room. """ self._subscriptions[id] = callback + if filepath_callback is not None: + self._filepath_subscriptions[id] = filepath_callback def unobserve(self, id: str) -> None: """ @@ -97,6 +106,8 @@ def unobserve(self, id: str) -> None: id (str): Room ID """ del self._subscriptions[id] + if id in self._filepath_subscriptions: + del self._filepath_subscriptions[id] async def load_content(self, format: str, file_type: str) -> dict[str, Any]: """ @@ -190,15 +201,26 @@ async def maybe_notify(self) -> None: Notifies subscribed rooms about out-of-band file changes. """ do_notify = False + filepath_change = False async with self._lock: + path = self.path + if self._current_path != path: + self._current_path = path + filepath_change = True + # Get model metadata; format and type are not need - model = await ensure_async(self._contents_manager.get(self.path, content=False)) + model = await ensure_async(self._contents_manager.get(path, content=False)) if self.last_modified is not None and self.last_modified < model["last_modified"]: do_notify = True self.last_modified = model["last_modified"] + if filepath_change: + # Notify filepath change + for callback in self._filepath_subscriptions.values(): + await ensure_async(callback()) + if do_notify: # Notify out-of-band change # callbacks will load the file content, thus release the lock before calling them diff --git a/jupyter_collaboration/rooms.py b/jupyter_collaboration/rooms.py index d2b0e28c..6cbfded2 100644 --- a/jupyter_collaboration/rooms.py +++ b/jupyter_collaboration/rooms.py @@ -42,6 +42,7 @@ def __init__( self._file_type: str = file_type self._file: FileLoader = file self._document = YDOCS.get(self._file_type, YFILE)(self.ydoc, self.awareness) + self._document.path = self._file.path self._logger = logger self._save_delay = save_delay @@ -54,7 +55,7 @@ def __init__( # Listen for document changes self._document.observe(self._on_document_change) - self._file.observe(self.room_id, self._on_outofband_change) + self._file.observe(self.room_id, self._on_outofband_change, self._on_filepath_change) @property def room_id(self) -> str: @@ -211,6 +212,12 @@ async def _on_outofband_change(self) -> None: self._document.source = model["content"] self._document.dirty = False + def _on_filepath_change(self) -> None: + """ + Update the document path property. + """ + self._document.path = self._file.path + def _on_document_change(self, target: str, event: Any) -> None: """ Called when the shared document changes. diff --git a/tests/test_rooms.py b/tests/test_rooms.py index b3f58872..2bdfa04a 100644 --- a/tests/test_rooms.py +++ b/tests/test_rooms.py @@ -75,3 +75,22 @@ async def test_undefined_save_delay_should_not_save_content_after_document_chang await asyncio.sleep(0.15) assert "save" not in cm.actions + + +async def test_document_path(rtc_create_mock_document_room): + id = "test-id" + path = "test.txt" + new_path = "test2.txt" + + _, loader, room = rtc_create_mock_document_room(id, path, "") + + await room.initialize() + assert room._document.path == path + + # Update the path + loader._file_id_manager.move(id, new_path) + + # Wait for a bit more than the poll_interval + await asyncio.sleep(0.15) + + assert room._document.path == new_path diff --git a/tests/utils.py b/tests/utils.py index 8114b673..f1140322 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,9 @@ def __init__(self, mapping: dict): def get_path(self, id: str) -> str: return self.mapping[id] + def move(self, id: str, new_path: str) -> None: + self.mapping[id] = new_path + class FakeContentsManager: def __init__(self, model: dict):