diff --git a/CHANGELOG.md b/CHANGELOG.md index 525ad56fe7..55d4ccf6a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Reactive `cell_padding` (and respective parameter) to define horizontal cell padding in data table columns https://github.com/Textualize/textual/issues/3435 - Added `Input.clear` method https://github.com/Textualize/textual/pull/3430 - Added `TextArea.SelectionChanged` and `TextArea.Changed` messages https://github.com/Textualize/textual/pull/3442 +- Added `wait_for_dismiss` parameter to `App.push_screen` https://github.com/Textualize/textual/pull/3477 ### Changed diff --git a/src/textual/app.py b/src/textual/app.py index bb06942226..7e6c174722 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -99,10 +99,11 @@ from .screen import Screen, ScreenResultCallbackType, ScreenResultType from .widget import AwaitMount, Widget from .widgets._toast import ToastRack +from .worker import NoActiveWorker, get_current_worker if TYPE_CHECKING: from textual_dev.client import DevtoolsClient - from typing_extensions import Coroutine, TypeAlias + from typing_extensions import Coroutine, Literal, TypeAlias from ._types import MessageTarget @@ -255,17 +256,14 @@ class App(Generic[ReturnType], DOMNode): and therefore takes priority in the event of a specificity clash.""" # Default (the lowest priority) CSS - DEFAULT_CSS: ClassVar[ - str - ] = """ + DEFAULT_CSS: ClassVar[str] + DEFAULT_CSS = """ App { background: $background; color: $text; } - *:disabled:can-focus { opacity: 0.7; - } """ @@ -1778,25 +1776,58 @@ def _replace_screen(self, screen: Screen) -> Screen: self.log.system(f"{screen} REMOVED") return screen + @overload def push_screen( self, screen: Screen[ScreenResultType] | str, callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: Literal[False] = False, ) -> AwaitMount: + ... + + @overload + def push_screen( + self, + screen: Screen[ScreenResultType] | str, + callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: Literal[True] = True, + ) -> asyncio.Future[ScreenResultType]: + ... + + def push_screen( + self, + screen: Screen[ScreenResultType] | str, + callback: ScreenResultCallbackType[ScreenResultType] | None = None, + wait_for_dismiss: bool = False, + ) -> AwaitMount | asyncio.Future[ScreenResultType]: """Push a new [screen](/guide/screens) on the screen stack, making it the current screen. Args: screen: A Screen instance or the name of an installed screen. callback: An optional callback function that will be called if the screen is [dismissed][textual.screen.Screen.dismiss] with a result. + wait_for_dismiss: If `True`, awaiting this method will return the dismiss value from the screen. When set to `False`, awaiting + this method will wait for the screen to be mounted. Note that `wait_for_dismiss` should only be set to `True` when running in a worker. + + Raises: + NoActiveWorker: If using `wait_for_dismiss` outside of a worker. Returns: - An optional awaitable that awaits the mounting of the screen and its children. + An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future + to await the result of the screen. """ if not isinstance(screen, (Screen, str)): raise TypeError( f"push_screen requires a Screen instance or str; not {screen!r}" ) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # Mainly for testing, when push_screen isn't called in an async context + future: asyncio.Future[ScreenResultType] = asyncio.Future() + else: + future = loop.create_future() + if self._screen_stack: self.screen.post_message(events.ScreenSuspend()) self.screen.refresh() @@ -1805,13 +1836,23 @@ def push_screen( message_pump = active_message_pump.get() except LookupError: message_pump = self.app - next_screen._push_result_callback(message_pump, callback) + + next_screen._push_result_callback(message_pump, callback, future) self._load_screen_css(next_screen) self._screen_stack.append(next_screen) self.stylesheet.update(next_screen) next_screen.post_message(events.ScreenResume()) self.log.system(f"{self.screen} is current (PUSHED)") - return await_mount + if wait_for_dismiss: + try: + get_current_worker() + except NoActiveWorker: + raise NoActiveWorker( + "push_screen must be run from a worker when `wait_for_dismiss` is True" + ) from None + return future + else: + return await_mount def switch_screen(self, screen: Screen | str) -> AwaitMount: """Switch to another [screen](/guide/screens) by replacing the top of the screen stack with a new screen. diff --git a/src/textual/screen.py b/src/textual/screen.py index c5404bc05f..be09b66e7c 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio from functools import partial from operator import attrgetter from typing import ( @@ -74,17 +75,21 @@ def __init__( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, + future: asyncio.Future[ScreenResultType] | None = None, ) -> None: """Initialise the result callback object. Args: requester: The object making a request for the callback. callback: The callback function. + future: A Future to hold the result. """ self.requester = requester """The object in the DOM that requested the callback.""" self.callback: ScreenResultCallbackType | None = callback """The callback function.""" + self.future = future + """A future for the result""" def __call__(self, result: ScreenResultType) -> None: """Call the callback, passing the given result. @@ -95,6 +100,8 @@ def __call__(self, result: ScreenResultType) -> None: Note: If the requested or the callback are `None` this will be a no-op. """ + if self.future is not None: + self.future.set_result(result) if self.requester is not None and self.callback is not None: self.requester.call_next(self.callback, result) @@ -687,15 +694,17 @@ def _push_result_callback( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, + future: asyncio.Future[ScreenResultType] | None = None, ) -> None: """Add a result callback to the screen. Args: requester: The object requesting the callback. callback: The callback. + future: A Future to hold the result. """ self._result_callbacks.append( - ResultCallback[ScreenResultType](requester, callback) + ResultCallback[ScreenResultType](requester, callback, future) ) def _pop_result_callback(self) -> None: diff --git a/tests/test_screens.py b/tests/test_screens.py index 5f587fd0ee..e5ddacb4d7 100644 --- a/tests/test_screens.py +++ b/tests/test_screens.py @@ -4,11 +4,13 @@ import pytest -from textual.app import App, ScreenStackError, ComposeResult +from textual import work +from textual.app import App, ComposeResult, ScreenStackError from textual.events import MouseMove from textual.geometry import Offset from textual.screen import Screen from textual.widgets import Button, Input, Label +from textual.worker import NoActiveWorker skip_py310 = pytest.mark.skipif( sys.version_info.minor == 10 and sys.version_info.major == 3, @@ -407,4 +409,71 @@ def on_mount(self): assert len(MouseMoveRecordingScreen.mouse_events) == 1 mouse_event = MouseMoveRecordingScreen.mouse_events[0] - assert mouse_event.x, mouse_event.y == (label_offset.x + mouse_offset.x, label_offset.y + mouse_offset.y) + assert mouse_event.x, mouse_event.y == ( + label_offset.x + mouse_offset.x, + label_offset.y + mouse_offset.y, + ) + + +async def test_push_screen_wait_for_dismiss() -> None: + """Test push_screen returns result.""" + + class QuitScreen(Screen[bool]): + BINDINGS = [ + ("y", "quit(True)"), + ("n", "quit(False)"), + ] + + def action_quit(self, quit: bool) -> None: + self.dismiss(quit) + + results: list[bool] = [] + + class ScreensApp(App): + BINDINGS = [("x", "exit")] + + @work + async def action_exit(self) -> None: + result = await self.push_screen(QuitScreen(), wait_for_dismiss=True) + results.append(result) + + app = ScreensApp() + # Press X to exit, then Y to dismiss, expect True result + async with app.run_test() as pilot: + await pilot.press("x", "y") + assert results == [True] + + results.clear() + app = ScreensApp() + # Press X to exit, then N to dismiss, expect False result + async with app.run_test() as pilot: + await pilot.press("x", "n") + assert results == [False] + + +async def test_push_screen_wait_for_dismiss_no_worker() -> None: + """Test wait_for_dismiss raises NoActiveWorker when not using workers.""" + + class QuitScreen(Screen[bool]): + BINDINGS = [ + ("y", "quit(True)"), + ("n", "quit(False)"), + ] + + def action_quit(self, quit: bool) -> None: + self.dismiss(quit) + + results: list[bool] = [] + + class ScreensApp(App): + BINDINGS = [("x", "exit")] + + async def action_exit(self) -> None: + result = await self.push_screen(QuitScreen(), wait_for_dismiss=True) + results.append(result) + + app = ScreensApp() + # using `wait_for_dismiss` outside of a worker should raise NoActiveWorker + with pytest.raises(NoActiveWorker): + async with app.run_test() as pilot: + await pilot.press("x", "y")