Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix file saving #231

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions jupyter_collaboration/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
] = {}

self._watcher = asyncio.create_task(self._watch_file()) if self._poll_interval else None
self.last_modified = None

@property
def file_id(self) -> str:
Expand Down Expand Up @@ -137,16 +138,13 @@ async def save_content(self, model: dict[str, Any]) -> dict[str, Any]:
self._log.info("Saving file: %s", path)
return await ensure_async(self._contents_manager.save(model, path))

async def maybe_save_content(self, model: dict[str, Any]) -> dict[str, Any]:
async def maybe_save_content(self, model: dict[str, Any]) -> None:
"""
Save the content of the file.

Parameters:
model (dict): A dictionary with format, type, last_modified, and content of the file.

Returns:
model (dict): A dictionary with the metadata and content of the file.

Raises:
OutOfBandChanges: if the file was modified at a latter time than the model

Expand All @@ -166,14 +164,21 @@ async def maybe_save_content(self, model: dict[str, Any]) -> dict[str, Any]:
)
)

if model["last_modified"] == m["last_modified"]:
if self.last_modified == m["last_modified"]:
self._log.info("Saving file: %s", path)
return await ensure_async(self._contents_manager.save(model, path))
# saving is shielded so that it cannot be cancelled
# otherwise it could corrupt the file
task = asyncio.create_task(self._save_content(model))
await asyncio.shield(task)

else:
# file changed on disk, raise an error
raise OutOfBandChanges

async def _save_content(self, model: dict[str, Any]) -> None:
m = await ensure_async(self._contents_manager.save(model, self.path))
self.last_modified = m["last_modified"]

async def _watch_file(self) -> None:
"""
Async task for watching a file.
Expand Down
13 changes: 5 additions & 8 deletions jupyter_collaboration/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
self._room_id: str = room_id
self._file_format: str = file_format
self._file_type: str = file_type
self._last_modified: Any = None
self._file: FileLoader = file
self._document = YDOCS.get(self._file_type, YFILE)(self.ydoc)

Expand Down Expand Up @@ -145,7 +144,7 @@ async def initialize(self) -> None:
if self.ystore:
await self.ystore.encode_state_as_update(self.ydoc)

self._last_modified = model["last_modified"]
self._file.last_modified = model["last_modified"]
self._document.dirty = False
self.ready = True
self._emit(LogLevel.INFO, "initialize", "Room initialized")
Expand Down Expand Up @@ -189,7 +188,7 @@ async def _on_content_change(self, event: str, args: dict[str, Any]) -> None:
args (dict): A dictionary with format, type, last_modified.
"""
if event == "metadata" and (
self._last_modified is None or self._last_modified < args["last_modified"]
self._file.last_modified is None or self._file.last_modified < args["last_modified"]
):
self.log.info("Out-of-band changes. Overwriting the content in room %s", self._room_id)
self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes. Overwriting the room.")
Expand All @@ -204,7 +203,7 @@ async def _on_content_change(self, event: str, args: dict[str, Any]) -> None:

async with self._update_lock:
self._document.source = model["content"]
self._last_modified = model["last_modified"]
self._file.last_modified = model["last_modified"]
self._document.dirty = False

def _on_document_change(self, target: str, event: Any) -> None:
Expand Down Expand Up @@ -249,15 +248,13 @@ async def _maybe_save_document(self) -> None:

try:
self.log.info("Saving the content from room %s", self._room_id)
model = await self._file.maybe_save_content(
await self._file.maybe_save_content(
{
"format": self._file_format,
"type": self._file_type,
"last_modified": self._last_modified,
"content": self._document.source,
}
)
self._last_modified = model["last_modified"]
async with self._update_lock:
self._document.dirty = False

Expand All @@ -275,7 +272,7 @@ async def _maybe_save_document(self) -> None:

async with self._update_lock:
self._document.source = model["content"]
self._last_modified = model["last_modified"]
self._file.last_modified = model["last_modified"]
self._document.dirty = False

self._emit(LogLevel.INFO, "overwrite", "Out-of-band changes while saving.")
Expand Down
Loading