Skip to content

Commit

Permalink
Merge pull request #4222 from davep/watcher-context
Browse files Browse the repository at this point in the history
Correct the sender of a reactive-posted message
  • Loading branch information
davep authored Feb 27, 2024
2 parents 201bd5b + 1a1b71c commit 8c519a5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/textual/reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from . import events
from ._callback import count_parameters
from ._context import active_message_pump
from ._types import (
MessageTarget,
WatchCallbackBothValuesType,
Expand Down Expand Up @@ -73,17 +74,23 @@ def invoke_watcher(
"""
_rich_traceback_omit = True
param_count = count_parameters(watch_function)
if param_count == 2:
watch_result = cast(WatchCallbackBothValuesType, watch_function)(
old_value, value
)
elif param_count == 1:
watch_result = cast(WatchCallbackNewValueType, watch_function)(value)
else:
watch_result = cast(WatchCallbackNoArgsType, watch_function)()
if isawaitable(watch_result):
# Result is awaitable, so we need to await it within an async context
watcher_object.call_next(partial(await_watcher, watcher_object, watch_result))
reset_token = active_message_pump.set(watcher_object)
try:
if param_count == 2:
watch_result = cast(WatchCallbackBothValuesType, watch_function)(
old_value, value
)
elif param_count == 1:
watch_result = cast(WatchCallbackNewValueType, watch_function)(value)
else:
watch_result = cast(WatchCallbackNoArgsType, watch_function)()
if isawaitable(watch_result):
# Result is awaitable, so we need to await it within an async context
watcher_object.call_next(
partial(await_watcher, watcher_object, watch_result)
)
finally:
active_message_pump.reset(reset_token)


@rich.repr.auto
Expand Down
39 changes: 39 additions & 0 deletions tests/test_reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest

from textual.app import App, ComposeResult
from textual.message import Message
from textual.message_pump import MessagePump
from textual.reactive import Reactive, TooManyComputesError, reactive, var
from textual.widget import Widget

Expand Down Expand Up @@ -705,3 +707,40 @@ def second_callback() -> None:
assert logs == ["first", "second"]
app.query_one(SomeWidget).test_var = 73
assert logs == ["first", "second", "first", "second"]


async def test_message_sender_from_reactive() -> None:
"""Test that the sender of a message comes from the reacting widget."""

message_senders: list[MessagePump | None] = []

class TestWidget(Widget):
test_var: var[int] = var(0, init=False)

class TestMessage(Message):
pass

def watch_test_var(self) -> None:
self.post_message(self.TestMessage())

def make_reaction(self) -> None:
self.test_var += 1

class TestContainer(Widget):
def compose(self) -> ComposeResult:
yield TestWidget()

def on_test_widget_test_message(self, event: TestWidget.TestMessage) -> None:
nonlocal message_senders
message_senders.append(event._sender)

class TestApp(App[None]):

def compose(self) -> ComposeResult:
yield TestContainer()

async with TestApp().run_test() as pilot:
assert message_senders == []
pilot.app.query_one(TestWidget).make_reaction()
await pilot.pause()
assert message_senders == [pilot.app.query_one(TestWidget)]

0 comments on commit 8c519a5

Please sign in to comment.