Skip to content

Commit

Permalink
rlock tests
Browse files Browse the repository at this point in the history
  • Loading branch information
willmcgugan committed Jun 6, 2024
1 parent 7925273 commit b10d24d
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
from .notifications import Notification, Notifications, Notify, SeverityLevel
from .reactive import Reactive
from .renderables.blank import Blank
from .rlock import RLock
from .screen import (
ActiveBinding,
Screen,
Expand Down Expand Up @@ -579,7 +580,7 @@ def __init__(
else None
)
self._screenshot: str | None = None
self._dom_lock = asyncio.Lock()
self._dom_lock = RLock()
self._dom_ready = False
self._batch_count = 0
self._notifications = Notifications()
Expand Down Expand Up @@ -3555,7 +3556,7 @@ def _refresh_notifications(self) -> None:
# or one will turn up. Things will work out later.
return
# Update the toast rack.
toast_rack.show(self._notifications)
self.call_later(toast_rack.show, self._notifications)

def notify(
self,
Expand Down
61 changes: 61 additions & 0 deletions src/textual/rlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from asyncio import Lock, Task, current_task


class RLock:
"""A re-entrant asyncio lock."""

def __init__(self) -> None:
self._owner: Task | None = None
self._count = 0
self._lock = Lock()

async def acquire(self) -> None:
"""Wait until the lock can be acquired."""
task = current_task()
assert task is not None
if self._owner is None or self._owner is not task:
await self._lock.acquire()
self._owner = task
self._count += 1

def release(self) -> None:
"""Release a previously acquired lock."""
task = current_task()
assert task is not None
self._count -= 1
if self._count < 0:
# Should not occur if every acquire as a release
raise RuntimeError("RLock.release called too many times")
if self._owner is task:
if not self._count:
self._owner = None
self._lock.release()

@property
def is_locked(self):
"""Return True if lock is acquired."""
return self._lock.locked()

async def __aenter__(self) -> None:
"""Asynchronous context manager to acquire and release lock."""
await self.acquire()

async def __aexit__(self, _type, _value, _traceback) -> None:
"""Exit the context manager."""
self.release()


if __name__ == "__main__":
from asyncio import Lock

async def locks():
lock = RLock()
async with lock:
async with lock:
print("Hello")

import asyncio

asyncio.run(locks())
1 change: 0 additions & 1 deletion src/textual/widgets/_toast.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def show(self, notifications: Notifications) -> None:
Args:
notifications: The notifications to show.
"""

# Look for any stale toasts and remove them.
for toast in self.query(Toast):
if toast._notification not in notifications:
Expand Down
56 changes: 56 additions & 0 deletions tests/test_rlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import asyncio

import pytest

from textual.rlock import RLock


async def test_simple_lock():
lock = RLock()
# Starts not locked
assert not lock.is_locked
# Acquire the lock
await lock.acquire()
assert lock.is_locked
# Acquire a second time (should not block)
await lock.acquire()
assert lock.is_locked

# Release the lock
lock.release()
# Should still be locked
assert lock.is_locked
# Release the lock
lock.release()
# Should be released
assert not lock.is_locked

# Another release is a runtime error
with pytest.raises(RuntimeError):
lock.release()


async def test_multiple_tasks() -> None:
"""Check RLock prevents other tasks from acquiring lock."""
lock = RLock()

started: list[int] = []
done: list[int] = []

async def test_task(n: int) -> None:
started.append(n)
async with lock:
done.append(n)

async with lock:
assert done == []
task1 = asyncio.create_task(test_task(1))
assert sorted(started) == []
task2 = asyncio.create_task(test_task(2))
await asyncio.sleep(0)
assert sorted(started) == [1, 2]

await task1
assert 1 in done
await task2
assert 2 in done

0 comments on commit b10d24d

Please sign in to comment.