Skip to content

Commit

Permalink
Handle last_modified only in FileLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 3, 2024
1 parent 0f98e92 commit 54c99d6
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 91 deletions.
83 changes: 37 additions & 46 deletions jupyter_collaboration/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def __init__(
self._contents_manager = contents_manager

self._log = log or getLogger(__name__)
self._subscriptions: dict[
str, Callable[[str, dict[str, Any]], Coroutine[Any, Any, None]]
] = {}
self._subscriptions: dict[str, Callable[[], Coroutine[Any, Any, None]]] = {}

self._watcher = asyncio.create_task(self._watch_file()) if self._poll_interval else None
self.last_modified = None
Expand Down Expand Up @@ -78,11 +76,9 @@ async def clean(self) -> None:
self._watcher.cancel()
await self._watcher

def observe(
self, id: str, callback: Callable[[str, dict[str, Any]], Coroutine[Any, Any, None]]
) -> None:
def observe(self, id: str, callback: Callable[[], Coroutine[Any, Any, None]]) -> None:
"""
Subscribe to the file to get notified on file changes.
Subscribe to the file to get notified about out-of-band file changes.
Parameters:
id (str): Room ID
Expand All @@ -99,7 +95,7 @@ def unobserve(self, id: str) -> None:
"""
del self._subscriptions[id]

async def load_content(self, format: str, file_type: str, content: bool) -> dict[str, Any]:
async def load_content(self, format: str, file_type: str) -> dict[str, Any]:
"""
Load the content of the file.
Expand All @@ -112,31 +108,11 @@ async def load_content(self, format: str, file_type: str, content: bool) -> dict
model (dict): A dictionary with the metadata and content of the file.
"""
async with self._lock:
return await ensure_async(
self._contents_manager.get(
self.path, format=format, type=file_type, content=content
)
model = await ensure_async(
self._contents_manager.get(self.path, format=format, type=file_type, content=True)
)

async def save_content(self, model: dict[str, Any]) -> dict[str, Any]:
"""
Save the content of the file.
Parameters:
model (dict): A dictionary with format, type, last_modified, and content of the file.
Returns:
model (dict): A dictionary with the metadata and content of the file.
"""
async with self._lock:
path = self.path
if model["type"] not in {"directory", "file", "notebook"}:
# fall back to file if unknown type, the content manager only knows
# how to handle these types
model["type"] = "file"

self._log.info("Saving file: %s", path)
return await ensure_async(self._contents_manager.save(model, path))
self.last_modified = model["last_modified"]
return model

async def maybe_save_content(self, model: dict[str, Any]) -> None:
"""
Expand Down Expand Up @@ -168,16 +144,24 @@ async def maybe_save_content(self, model: dict[str, Any]) -> None:
self._log.info("Saving file: %s", path)
# saving is shielded so that it cannot be cancelled
# otherwise it could corrupt the file
task = asyncio.create_task(self._save_content(model))
await asyncio.shield(task)

done_saving = asyncio.Event()
task = asyncio.create_task(self._save_content(model, done_saving))
try:
await asyncio.shield(task)
except asyncio.CancelledError:
pass
await done_saving.wait()
else:
# file changed on disk, raise an error
self.last_modified = m["last_modified"]
raise OutOfBandChanges

async def _save_content(self, model: dict[str, Any]) -> None:
m = await ensure_async(self._contents_manager.save(model, self.path))
self.last_modified = m["last_modified"]
async def _save_content(self, model: dict[str, Any], done_saving: asyncio.Event) -> None:
try:
m = await ensure_async(self._contents_manager.save(model, self.path))
self.last_modified = m["last_modified"]
finally:
done_saving.set()

async def _watch_file(self) -> None:
"""
Expand All @@ -192,24 +176,31 @@ async def _watch_file(self) -> None:
try:
await asyncio.sleep(self._poll_interval)
try:
await self.notify()
await self.maybe_notify()
except Exception as e:
self._log.error(f"Error watching file: {self.path}\n{e!r}", exc_info=e)
except asyncio.CancelledError:
break

async def notify(self) -> None:
async def maybe_notify(self) -> None:
"""
Notifies subscribed rooms about changes on the content of the file.
Notifies subscribed rooms about out-of-band file changes.
"""
do_notify = False
async with self._lock:
path = self.path
# Get model metadata; format and type are not need
model = await ensure_async(self._contents_manager.get(path, content=False))
model = await ensure_async(self._contents_manager.get(self.path, content=False))

if self.last_modified is not None and self.last_modified < model["last_modified"]:
do_notify = True

self.last_modified = model["last_modified"]

# Notify that the content changed on disk
for callback in self._subscriptions.values():
await callback("metadata", model)
if do_notify:
# Notify out-of-band change
# callbacks will load the file content, thus release the lock before calling them
for callback in self._subscriptions.values():
await callback()


class FileLoaderMapping:
Expand Down
60 changes: 26 additions & 34 deletions jupyter_collaboration/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(

# Listen for document changes
self._document.observe(self._on_document_change)
self._file.observe(self.room_id, self._on_content_change)
self._file.observe(self.room_id, self._on_outofband_change)

@property
def room_id(self) -> str:
Expand Down Expand Up @@ -95,7 +95,7 @@ async def initialize(self) -> None:

self.log.info("Initializing room %s", self._room_id)

model = await self._file.load_content(self._file_format, self._file_type, True)
model = await self._file.load_content(self._file_format, self._file_type)

async with self._update_lock:
# try to apply Y updates from the YStore for this document
Expand Down Expand Up @@ -144,7 +144,6 @@ async def initialize(self) -> None:
if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)

self._file.last_modified = model["last_modified"]
self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")
Expand Down Expand Up @@ -179,32 +178,24 @@ async def _broadcast_updates(self):
except asyncio.CancelledError:
pass

async def _on_content_change(self, event: str, args: dict[str, Any]) -> None:
async def _on_outofband_change(self) -> None:
"""
Called when the file changes.
Parameters:
event (str): Type of change.
args (dict): A dictionary with format, type, last_modified.
Called when the file got out-of-band changes.
"""
if event == "metadata" and (
self._file.last_modified is None or self._file.last_modified < args["last_modified"]
):
self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id)
self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.")
self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id)
self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.")

try:
model = await self._file.load_content(self._file_format, self._file_type, True)
except Exception as e:
msg = f"Error loading content from file: {self._file.path}\n{e!r}"
self.log.error(msg, exc_info=e)
self._emit(LogLevel.ERROR, None, msg)
return None
try:
model = await self._file.load_content(self._file_format, self._file_type)
except Exception as e:
msg = f"Error loading content from file: {self._file.path}\n{e!r}"
self.log.error(msg, exc_info=e)
self._emit(LogLevel.ERROR, None, msg)
return

async with self._update_lock:
self._document.source = model["content"]
self._file.last_modified = model["last_modified"]
self._document.dirty = False
async with self._update_lock:
self._document.source = model["content"]
self._document.dirty = False

def _on_document_change(self, target: str, event: Any) -> None:
"""
Expand All @@ -224,14 +215,11 @@ def _on_document_change(self, target: str, event: Any) -> None:
if self._update_lock.locked():
return

if self._saving_document is not None and not self._saving_document.done():
# the document is being saved, cancel that
self._saving_document.cancel()
self._saving_document = None

self._saving_document = asyncio.create_task(self._maybe_save_document())
self._saving_document = asyncio.create_task(
self._maybe_save_document(self._saving_document)
)

async def _maybe_save_document(self) -> None:
async def _maybe_save_document(self, saving_document: asyncio.Task | None) -> None:
"""
Saves the content of the document to disk.
Expand All @@ -243,6 +231,11 @@ async def _maybe_save_document(self) -> None:
if self._save_delay is None:
return

if saving_document is not None and not saving_document.done():
# the document is being saved, cancel that
saving_document.cancel()
await saving_document

# save after X seconds of inactivity
await asyncio.sleep(self._save_delay)

Expand All @@ -263,7 +256,7 @@ async def _maybe_save_document(self) -> None:
except OutOfBandChanges:
self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id)
try:
model = await self._file.load_content(self._file_format, self._file_type, True)
model = await self._file.load_content(self._file_format, self._file_type)
except Exception as e:
msg = f"Error loading content from file: {self._file.path}\n{e!r}"
self.log.error(msg, exc_info=e)
Expand All @@ -272,7 +265,6 @@ async def _maybe_save_document(self) -> None:

async with self._update_lock:
self._document.source = model["content"]
self._file.last_modified = model["last_modified"]
self._document.dirty = False

self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes while saving.")
Expand Down
25 changes: 14 additions & 11 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime
from datetime import datetime, timedelta, timezone

from jupyter_collaboration.loaders import FileLoader, FileLoaderMapping

Expand All @@ -17,23 +17,24 @@ async def test_FileLoader_with_watcher():
paths = {}
paths[id] = path

cm = FakeContentsManager({"last_modified": datetime.now()})
cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)})
loader = FileLoader(
id,
FakeFileIDManager(paths),
cm,
poll_interval=0.1,
)
await loader.load_content("text", "file")

triggered = False

async def trigger(*args):
async def trigger():
nonlocal triggered
triggered = True

loader.observe("test", trigger)

cm.model["last_modified"] = datetime.now()
cm.model["last_modified"] = datetime.now(timezone.utc) + timedelta(seconds=1)

await asyncio.sleep(0.15)

Expand All @@ -49,24 +50,25 @@ async def test_FileLoader_without_watcher():
paths = {}
paths[id] = path

cm = FakeContentsManager({"last_modified": datetime.now()})
cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)})
loader = FileLoader(
id,
FakeFileIDManager(paths),
cm,
)
await loader.load_content("text", "file")

triggered = False

async def trigger(*args):
async def trigger():
nonlocal triggered
triggered = True

loader.observe("test", trigger)

cm.model["last_modified"] = datetime.now()
cm.model["last_modified"] = datetime.now(timezone.utc) + timedelta(seconds=1)

await loader.notify()
await loader.maybe_notify()

try:
assert triggered
Expand All @@ -80,26 +82,27 @@ async def test_FileLoaderMapping_with_watcher():
paths = {}
paths[id] = path

cm = FakeContentsManager({"last_modified": datetime.now()})
cm = FakeContentsManager({"last_modified": datetime.now(timezone.utc)})

map = FileLoaderMapping(
{"contents_manager": cm, "file_id_manager": FakeFileIDManager(paths)},
file_poll_interval=1.0,
)

loader = map[id]
await loader.load_content("text", "file")

triggered = False

async def trigger(*args):
async def trigger():
nonlocal triggered
triggered = True

loader.observe("test", trigger)

# Clear map (and its loader) before updating => triggered should be False
await map.clear()
cm.model["last_modified"] = datetime.now()
cm.model["last_modified"] = datetime.now(timezone.utc)

await asyncio.sleep(0.15)

Expand Down

0 comments on commit 54c99d6

Please sign in to comment.