Skip to content

Commit

Permalink
Add wait_for_dismiss to push_screen (#3477)
Browse files Browse the repository at this point in the history
* docstrings

* raises docstring

* fix for tests

* Formatting

* tests

* changelog

* simplify

* typing

* dot

* typing

* Update tests/test_screens.py

Co-authored-by: Rodrigo Girão Serrão <[email protected]>

---------

Co-authored-by: Rodrigo Girão Serrão <[email protected]>
  • Loading branch information
willmcgugan and rodrigogiraoserrao authored Oct 9, 2023
1 parent 005f556 commit b8ac737
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Reactive `cell_padding` (and respective parameter) to define horizontal cell padding in data table columns https://github.com/Textualize/textual/issues/3435
- Added `Input.clear` method https://github.com/Textualize/textual/pull/3430
- Added `TextArea.SelectionChanged` and `TextArea.Changed` messages https://github.com/Textualize/textual/pull/3442
- Added `wait_for_dismiss` parameter to `App.push_screen` https://github.com/Textualize/textual/pull/3477

### Changed

Expand Down
59 changes: 50 additions & 9 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@
from .screen import Screen, ScreenResultCallbackType, ScreenResultType
from .widget import AwaitMount, Widget
from .widgets._toast import ToastRack
from .worker import NoActiveWorker, get_current_worker

if TYPE_CHECKING:
from textual_dev.client import DevtoolsClient
from typing_extensions import Coroutine, TypeAlias
from typing_extensions import Coroutine, Literal, TypeAlias

from ._types import MessageTarget

Expand Down Expand Up @@ -255,17 +256,14 @@ class App(Generic[ReturnType], DOMNode):
and therefore takes priority in the event of a specificity clash."""

# Default (the lowest priority) CSS
DEFAULT_CSS: ClassVar[
str
] = """
DEFAULT_CSS: ClassVar[str]
DEFAULT_CSS = """
App {
background: $background;
color: $text;
}
*:disabled:can-focus {
opacity: 0.7;
}
"""

Expand Down Expand Up @@ -1778,25 +1776,58 @@ def _replace_screen(self, screen: Screen) -> Screen:
self.log.system(f"{screen} REMOVED")
return screen

@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[False] = False,
) -> AwaitMount:
...

@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[True] = True,
) -> asyncio.Future[ScreenResultType]:
...

def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: bool = False,
) -> AwaitMount | asyncio.Future[ScreenResultType]:
"""Push a new [screen](/guide/screens) on the screen stack, making it the current screen.
Args:
screen: A Screen instance or the name of an installed screen.
callback: An optional callback function that will be called if the screen is [dismissed][textual.screen.Screen.dismiss] with a result.
wait_for_dismiss: If `True`, awaiting this method will return the dismiss value from the screen. When set to `False`, awaiting
this method will wait for the screen to be mounted. Note that `wait_for_dismiss` should only be set to `True` when running in a worker.
Raises:
NoActiveWorker: If using `wait_for_dismiss` outside of a worker.
Returns:
An optional awaitable that awaits the mounting of the screen and its children.
An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future
to await the result of the screen.
"""
if not isinstance(screen, (Screen, str)):
raise TypeError(
f"push_screen requires a Screen instance or str; not {screen!r}"
)

try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Mainly for testing, when push_screen isn't called in an async context
future: asyncio.Future[ScreenResultType] = asyncio.Future()
else:
future = loop.create_future()

if self._screen_stack:
self.screen.post_message(events.ScreenSuspend())
self.screen.refresh()
Expand All @@ -1805,13 +1836,23 @@ def push_screen(
message_pump = active_message_pump.get()
except LookupError:
message_pump = self.app
next_screen._push_result_callback(message_pump, callback)

next_screen._push_result_callback(message_pump, callback, future)
self._load_screen_css(next_screen)
self._screen_stack.append(next_screen)
self.stylesheet.update(next_screen)
next_screen.post_message(events.ScreenResume())
self.log.system(f"{self.screen} is current (PUSHED)")
return await_mount
if wait_for_dismiss:
try:
get_current_worker()
except NoActiveWorker:
raise NoActiveWorker(
"push_screen must be run from a worker when `wait_for_dismiss` is True"
) from None
return future
else:
return await_mount

def switch_screen(self, screen: Screen | str) -> AwaitMount:
"""Switch to another [screen](/guide/screens) by replacing the top of the screen stack with a new screen.
Expand Down
11 changes: 10 additions & 1 deletion src/textual/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import asyncio
from functools import partial
from operator import attrgetter
from typing import (
Expand Down Expand Up @@ -74,17 +75,21 @@ def __init__(
self,
requester: MessagePump,
callback: ScreenResultCallbackType[ScreenResultType] | None,
future: asyncio.Future[ScreenResultType] | None = None,
) -> None:
"""Initialise the result callback object.
Args:
requester: The object making a request for the callback.
callback: The callback function.
future: A Future to hold the result.
"""
self.requester = requester
"""The object in the DOM that requested the callback."""
self.callback: ScreenResultCallbackType | None = callback
"""The callback function."""
self.future = future
"""A future for the result"""

def __call__(self, result: ScreenResultType) -> None:
"""Call the callback, passing the given result.
Expand All @@ -95,6 +100,8 @@ def __call__(self, result: ScreenResultType) -> None:
Note:
If the requested or the callback are `None` this will be a no-op.
"""
if self.future is not None:
self.future.set_result(result)
if self.requester is not None and self.callback is not None:
self.requester.call_next(self.callback, result)

Expand Down Expand Up @@ -687,15 +694,17 @@ def _push_result_callback(
self,
requester: MessagePump,
callback: ScreenResultCallbackType[ScreenResultType] | None,
future: asyncio.Future[ScreenResultType] | None = None,
) -> None:
"""Add a result callback to the screen.
Args:
requester: The object requesting the callback.
callback: The callback.
future: A Future to hold the result.
"""
self._result_callbacks.append(
ResultCallback[ScreenResultType](requester, callback)
ResultCallback[ScreenResultType](requester, callback, future)
)

def _pop_result_callback(self) -> None:
Expand Down
73 changes: 71 additions & 2 deletions tests/test_screens.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import pytest

from textual.app import App, ScreenStackError, ComposeResult
from textual import work
from textual.app import App, ComposeResult, ScreenStackError
from textual.events import MouseMove
from textual.geometry import Offset
from textual.screen import Screen
from textual.widgets import Button, Input, Label
from textual.worker import NoActiveWorker

skip_py310 = pytest.mark.skipif(
sys.version_info.minor == 10 and sys.version_info.major == 3,
Expand Down Expand Up @@ -407,4 +409,71 @@ def on_mount(self):

assert len(MouseMoveRecordingScreen.mouse_events) == 1
mouse_event = MouseMoveRecordingScreen.mouse_events[0]
assert mouse_event.x, mouse_event.y == (label_offset.x + mouse_offset.x, label_offset.y + mouse_offset.y)
assert mouse_event.x, mouse_event.y == (
label_offset.x + mouse_offset.x,
label_offset.y + mouse_offset.y,
)


async def test_push_screen_wait_for_dismiss() -> None:
"""Test push_screen returns result."""

class QuitScreen(Screen[bool]):
BINDINGS = [
("y", "quit(True)"),
("n", "quit(False)"),
]

def action_quit(self, quit: bool) -> None:
self.dismiss(quit)

results: list[bool] = []

class ScreensApp(App):
BINDINGS = [("x", "exit")]

@work
async def action_exit(self) -> None:
result = await self.push_screen(QuitScreen(), wait_for_dismiss=True)
results.append(result)

app = ScreensApp()
# Press X to exit, then Y to dismiss, expect True result
async with app.run_test() as pilot:
await pilot.press("x", "y")
assert results == [True]

results.clear()
app = ScreensApp()
# Press X to exit, then N to dismiss, expect False result
async with app.run_test() as pilot:
await pilot.press("x", "n")
assert results == [False]


async def test_push_screen_wait_for_dismiss_no_worker() -> None:
"""Test wait_for_dismiss raises NoActiveWorker when not using workers."""

class QuitScreen(Screen[bool]):
BINDINGS = [
("y", "quit(True)"),
("n", "quit(False)"),
]

def action_quit(self, quit: bool) -> None:
self.dismiss(quit)

results: list[bool] = []

class ScreensApp(App):
BINDINGS = [("x", "exit")]

async def action_exit(self) -> None:
result = await self.push_screen(QuitScreen(), wait_for_dismiss=True)
results.append(result)

app = ScreensApp()
# using `wait_for_dismiss` outside of a worker should raise NoActiveWorker
with pytest.raises(NoActiveWorker):
async with app.run_test() as pilot:
await pilot.press("x", "y")

0 comments on commit b8ac737

Please sign in to comment.