diff --git a/src/textual/reactive.py b/src/textual/reactive.py index ec6703835f..f2ad9cc586 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -24,6 +24,7 @@ from . import events from ._callback import count_parameters +from ._context import active_message_pump from ._types import ( MessageTarget, WatchCallbackBothValuesType, @@ -73,17 +74,23 @@ def invoke_watcher( """ _rich_traceback_omit = True param_count = count_parameters(watch_function) - if param_count == 2: - watch_result = cast(WatchCallbackBothValuesType, watch_function)( - old_value, value - ) - elif param_count == 1: - watch_result = cast(WatchCallbackNewValueType, watch_function)(value) - else: - watch_result = cast(WatchCallbackNoArgsType, watch_function)() - if isawaitable(watch_result): - # Result is awaitable, so we need to await it within an async context - watcher_object.call_next(partial(await_watcher, watcher_object, watch_result)) + reset_token = active_message_pump.set(watcher_object) + try: + if param_count == 2: + watch_result = cast(WatchCallbackBothValuesType, watch_function)( + old_value, value + ) + elif param_count == 1: + watch_result = cast(WatchCallbackNewValueType, watch_function)(value) + else: + watch_result = cast(WatchCallbackNoArgsType, watch_function)() + if isawaitable(watch_result): + # Result is awaitable, so we need to await it within an async context + watcher_object.call_next( + partial(await_watcher, watcher_object, watch_result) + ) + finally: + active_message_pump.reset(reset_token) @rich.repr.auto diff --git a/tests/test_reactive.py b/tests/test_reactive.py index 7bddb3df79..8a22d70a13 100644 --- a/tests/test_reactive.py +++ b/tests/test_reactive.py @@ -5,6 +5,8 @@ import pytest from textual.app import App, ComposeResult +from textual.message import Message +from textual.message_pump import MessagePump from textual.reactive import Reactive, TooManyComputesError, reactive, var from textual.widget import Widget @@ -705,3 +707,40 @@ def second_callback() -> None: assert logs == ["first", "second"] app.query_one(SomeWidget).test_var = 73 assert logs == ["first", "second", "first", "second"] + + +async def test_message_sender_from_reactive() -> None: + """Test that the sender of a message comes from the reacting widget.""" + + message_senders: list[MessagePump | None] = [] + + class TestWidget(Widget): + test_var: var[int] = var(0, init=False) + + class TestMessage(Message): + pass + + def watch_test_var(self) -> None: + self.post_message(self.TestMessage()) + + def make_reaction(self) -> None: + self.test_var += 1 + + class TestContainer(Widget): + def compose(self) -> ComposeResult: + yield TestWidget() + + def on_test_widget_test_message(self, event: TestWidget.TestMessage) -> None: + nonlocal message_senders + message_senders.append(event._sender) + + class TestApp(App[None]): + + def compose(self) -> ComposeResult: + yield TestContainer() + + async with TestApp().run_test() as pilot: + assert message_senders == [] + pilot.app.query_one(TestWidget).make_reaction() + await pilot.pause() + assert message_senders == [pilot.app.query_one(TestWidget)]