Skip to content

Commit

Permalink
signal arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
willmcgugan committed Apr 22, 2024
1 parent 1de74dc commit 8a9746d
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 24 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- Fixed `TextArea` to end mouse selection only if currently selecting https://github.com/Textualize/textual/pull/4436

### Changed

- Added argument to signal callbacks

## [0.57.1] - 2024-04-20

### Fixed
Expand Down
8 changes: 4 additions & 4 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,15 +603,15 @@ def __init__(
self._original_stderr = sys.__stderr__
"""The original stderr stream (before redirection etc)."""

self.app_suspend_signal = Signal(self, "app-suspend")
self.app_suspend_signal: Signal[App] = Signal(self, "app-suspend")
"""The signal that is published when the app is suspended.
When [`App.suspend`][textual.app.App.suspend] is called this signal
will be [published][textual.signal.Signal.publish];
[subscribe][textual.signal.Signal.subscribe] to this signal to
perform work before the suspension takes place.
"""
self.app_resume_signal = Signal(self, "app-resume")
self.app_resume_signal: Signal[App] = Signal(self, "app-resume")
"""The signal that is published when the app is resumed after a suspend.
When the app is resumed after a
Expand Down Expand Up @@ -3569,12 +3569,12 @@ def action_command_palette(self) -> None:

def _suspend_signal(self) -> None:
"""Signal that the application is being suspended."""
self.app_suspend_signal.publish()
self.app_suspend_signal.publish(self)

@on(Driver.SignalResume)
def _resume_signal(self) -> None:
"""Signal that the application is being resumed from a suspension."""
self.app_resume_signal.publish()
self.app_resume_signal.publish(self)

@contextmanager
def suspend(self) -> Iterator[None]:
Expand Down
8 changes: 5 additions & 3 deletions src/textual/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ def __init__(
self.title = self.TITLE
self.sub_title = self.SUB_TITLE

self.screen_layout_refresh_signal = Signal(self, "layout-refresh")
self.screen_layout_refresh_signal: Signal[Screen] = Signal(
self, "layout-refresh"
)
"""The signal that is published when the screen's layout is refreshed."""

@property
Expand Down Expand Up @@ -861,7 +863,7 @@ def _refresh_layout(self, size: Size | None = None, scroll: bool = False) -> Non
self._compositor_refresh()

if self.app._dom_ready:
self.screen_layout_refresh_signal.publish()
self.screen_layout_refresh_signal.publish(self.screen)
else:
self.app.post_message(events.Ready())
self.app._dom_ready = True
Expand Down Expand Up @@ -966,7 +968,7 @@ def _clear_tooltip(self) -> None:
self._tooltip_timer.stop()
tooltip.display = False

def _maybe_clear_tooltip(self) -> None:
def _maybe_clear_tooltip(self, _) -> None:
"""Check if the widget under the mouse cursor still pertains to the tooltip.
If they differ, the tooltip will be removed.
Expand Down
32 changes: 21 additions & 11 deletions src/textual/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Generic, TypeVar, Union
from weakref import WeakKeyDictionary

import rich.repr

from textual import log

if TYPE_CHECKING:
from ._types import IgnoreReturnCallbackType
from .dom import DOMNode

SignalT = TypeVar("SignalT")

SignalCallbackType = Union[
Callable[[SignalT], Awaitable[Any]], Callable[[SignalT], Any]
]


class SignalError(Exception):
"""Raised for Signal errors."""


@rich.repr.auto(angular=True)
class Signal:
class Signal(Generic[SignalT]):
"""A signal that a widget may subscribe to, in order to invoke callbacks when an associated event occurs."""

def __init__(self, owner: DOMNode, name: str) -> None:
Expand All @@ -38,23 +43,23 @@ def __init__(self, owner: DOMNode, name: str) -> None:
"""
self._owner = owner
self._name = name
self._subscriptions: WeakKeyDictionary[
DOMNode, list[IgnoreReturnCallbackType]
] = WeakKeyDictionary()
self._subscriptions: WeakKeyDictionary[DOMNode, list[SignalCallbackType]] = (
WeakKeyDictionary()
)

def __rich_repr__(self) -> rich.repr.Result:
yield "owner", self._owner
yield "name", self._name
yield "subscriptions", list(self._subscriptions.keys())

def subscribe(self, node: DOMNode, callback: IgnoreReturnCallbackType) -> None:
def subscribe(self, node: DOMNode, callback: SignalCallbackType) -> None:
"""Subscribe a node to this signal.
When the signal is published, the callback will be invoked.
Args:
node: Node to subscribe.
callback: A callback function which takes no arguments, and returns anything (return type ignored).
callback: A callback function which takes a single argument and returns anything (return type ignored).
Raises:
SignalError: Raised when subscribing a non-mounted widget.
Expand All @@ -75,8 +80,13 @@ def unsubscribe(self, node: DOMNode) -> None:
"""
self._subscriptions.pop(node, None)

def publish(self) -> None:
"""Publish the signal (invoke subscribed callbacks)."""
def publish(self, data: SignalT) -> None:
"""Publish the signal (invoke subscribed callbacks).
Args:
data: An argument to pass to the callbacks.
"""

for node, callbacks in list(self._subscriptions.items()):
if not node.is_running:
Expand All @@ -86,7 +96,7 @@ def publish(self) -> None:
# Call callbacks
for callback in callbacks:
try:
callback()
callback(data)
except Exception as error:
log.error(
f"error publishing signal to {node} ignored (callback={callback}); {error}"
Expand Down
43 changes: 39 additions & 4 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async def test_signal():

class TestLabel(Label):
def on_mount(self) -> None:
def signal_result():
def signal_result(_):
nonlocal called
called += 1

Expand All @@ -22,14 +22,14 @@ class TestApp(App):
BINDINGS = [("space", "signal")]

def __init__(self) -> None:
self.test_signal = Signal(self, "coffee ready")
self.test_signal: Signal[str] = Signal(self, "coffee ready")
super().__init__()

def compose(self) -> ComposeResult:
yield TestLabel()

def action_signal(self) -> None:
self.test_signal.publish()
self.test_signal.publish("foo")

app = TestApp()
async with app.run_test() as pilot:
Expand Down Expand Up @@ -65,11 +65,46 @@ def test_signal_errors():
label = Label()
# Check subscribing a non-running widget is an error
with pytest.raises(SignalError):
test_signal.subscribe(label, lambda: None)
test_signal.subscribe(label, lambda _: None)


def test_repr():
"""Check the repr doesn't break."""
app = App()
test_signal = Signal(app, "test")
assert isinstance(repr(test_signal), str)


async def test_signal_parameters():
str_result: str | None = None
int_result: int | None = None

class TestApp(App):
BINDINGS = [("space", "signal")]

def __init__(self) -> None:
self.str_signal: Signal[str] = Signal(self, "str")
self.int_signal: Signal[int] = Signal(self, "int")
super().__init__()

def action_signal(self) -> None:
self.str_signal.publish("foo")
self.int_signal.publish(3)

def on_mount(self) -> None:
def on_str(my_str):
nonlocal str_result
str_result = my_str

def on_int(my_int):
nonlocal int_result
int_result = my_int

self.str_signal.subscribe(self, on_str)
self.int_signal.subscribe(self, on_int)

app = TestApp()
async with app.run_test() as pilot:
await pilot.press("space")
assert str_result == "foo"
assert int_result == 3
4 changes: 2 additions & 2 deletions tests/test_suspend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def resume_application_mode(self) -> None:
calls.add("resume")

class SuspendApp(App[None]):
def on_suspend(self) -> None:
def on_suspend(self, _) -> None:
nonlocal calls
calls.add("suspend signal")

def on_resume(self) -> None:
def on_resume(self, _) -> None:
nonlocal calls
calls.add("resume signal")

Expand Down

0 comments on commit 8a9746d

Please sign in to comment.