diff --git a/src/textual/app.py b/src/textual/app.py index ccbe85d3cc..07b5e4f1a3 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -103,6 +103,7 @@ from .notifications import Notification, Notifications, Notify, SeverityLevel from .reactive import Reactive from .renderables.blank import Blank +from .rlock import RLock from .screen import ( ActiveBinding, Screen, @@ -579,7 +580,7 @@ def __init__( else None ) self._screenshot: str | None = None - self._dom_lock = asyncio.Lock() + self._dom_lock = RLock() self._dom_ready = False self._batch_count = 0 self._notifications = Notifications() @@ -3555,7 +3556,7 @@ def _refresh_notifications(self) -> None: # or one will turn up. Things will work out later. return # Update the toast rack. - toast_rack.show(self._notifications) + self.call_later(toast_rack.show, self._notifications) def notify( self, diff --git a/src/textual/rlock.py b/src/textual/rlock.py new file mode 100644 index 0000000000..d7a6af2d5e --- /dev/null +++ b/src/textual/rlock.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from asyncio import Lock, Task, current_task + + +class RLock: + """A re-entrant asyncio lock.""" + + def __init__(self) -> None: + self._owner: Task | None = None + self._count = 0 + self._lock = Lock() + + async def acquire(self) -> None: + """Wait until the lock can be acquired.""" + task = current_task() + assert task is not None + if self._owner is None or self._owner is not task: + await self._lock.acquire() + self._owner = task + self._count += 1 + + def release(self) -> None: + """Release a previously acquired lock.""" + task = current_task() + assert task is not None + self._count -= 1 + if self._count < 0: + # Should not occur if every acquire as a release + raise RuntimeError("RLock.release called too many times") + if self._owner is task: + if not self._count: + self._owner = None + self._lock.release() + + @property + def is_locked(self): + """Return True if lock is acquired.""" + return self._lock.locked() + + async def __aenter__(self) -> None: + """Asynchronous context manager to acquire and release lock.""" + await self.acquire() + + async def __aexit__(self, _type, _value, _traceback) -> None: + """Exit the context manager.""" + self.release() + + +if __name__ == "__main__": + from asyncio import Lock + + async def locks(): + lock = RLock() + async with lock: + async with lock: + print("Hello") + + import asyncio + + asyncio.run(locks()) diff --git a/src/textual/widgets/_toast.py b/src/textual/widgets/_toast.py index a8198f4b53..0a057523f7 100644 --- a/src/textual/widgets/_toast.py +++ b/src/textual/widgets/_toast.py @@ -183,7 +183,6 @@ def show(self, notifications: Notifications) -> None: Args: notifications: The notifications to show. """ - # Look for any stale toasts and remove them. for toast in self.query(Toast): if toast._notification not in notifications: diff --git a/tests/test_rlock.py b/tests/test_rlock.py new file mode 100644 index 0000000000..40cc4281b2 --- /dev/null +++ b/tests/test_rlock.py @@ -0,0 +1,56 @@ +import asyncio + +import pytest + +from textual.rlock import RLock + + +async def test_simple_lock(): + lock = RLock() + # Starts not locked + assert not lock.is_locked + # Acquire the lock + await lock.acquire() + assert lock.is_locked + # Acquire a second time (should not block) + await lock.acquire() + assert lock.is_locked + + # Release the lock + lock.release() + # Should still be locked + assert lock.is_locked + # Release the lock + lock.release() + # Should be released + assert not lock.is_locked + + # Another release is a runtime error + with pytest.raises(RuntimeError): + lock.release() + + +async def test_multiple_tasks() -> None: + """Check RLock prevents other tasks from acquiring lock.""" + lock = RLock() + + started: list[int] = [] + done: list[int] = [] + + async def test_task(n: int) -> None: + started.append(n) + async with lock: + done.append(n) + + async with lock: + assert done == [] + task1 = asyncio.create_task(test_task(1)) + assert sorted(started) == [] + task2 = asyncio.create_task(test_task(2)) + await asyncio.sleep(0) + assert sorted(started) == [1, 2] + + await task1 + assert 1 in done + await task2 + assert 2 in done