Skip to content

Commit

Permalink
await screens
Browse files Browse the repository at this point in the history
  • Loading branch information
willmcgugan committed Jun 24, 2024
1 parent c9c34c7 commit ee71b34
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 42 deletions.
6 changes: 4 additions & 2 deletions src/textual/_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,7 @@ async def wait_for_complete(self, workers: Iterable[Worker] | None = None) -> No
Args:
workers: An iterable of workers or None to wait for all workers in the manager.
"""

await asyncio.gather(*[worker.wait() for worker in (workers or self)])
try:
await asyncio.gather(*[worker.wait() for worker in (workers or self)])
except asyncio.CancelledError:
pass
73 changes: 45 additions & 28 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from ._wait import wait_for_idle
from ._worker_manager import WorkerManager
from .actions import ActionParseResult, SkipAction
from .await_complete import AwaitComplete
from .await_remove import AwaitRemove
from .binding import Binding, BindingType, _Bindings
from .command import CommandPalette, Provider
Expand Down Expand Up @@ -1569,8 +1570,10 @@ async def run_auto_pilot(
if auto_pilot_task is not None:
await auto_pilot_task
finally:
await app._shutdown()

try:
await asyncio.shield(app._shutdown())
except asyncio.CancelledError:
pass
return app.return_value

def run(
Expand Down Expand Up @@ -1910,7 +1913,7 @@ def add_mode(

self.MODES[mode] = base_screen

def remove_mode(self, mode: str) -> None:
def remove_mode(self, mode: str) -> AwaitComplete:
"""Removes a mode from the app.
Screens that are running in the stack of that mode are scheduled for pruning.
Expand All @@ -1930,12 +1933,16 @@ def remove_mode(self, mode: str) -> None:
del self.MODES[mode]

if mode not in self._screen_stacks:
return
return AwaitComplete.nothing()

stack = self._screen_stacks[mode]
del self._screen_stacks[mode]
for screen in reversed(stack):
self._replace_screen(screen)

async def remove_screens():
for screen in reversed(stack):
await self._replace_screen(screen)

return AwaitComplete(remove_screens()).call_next(self)

def is_screen_installed(self, screen: Screen | str) -> bool:
"""Check if a given screen has been installed.
Expand Down Expand Up @@ -2030,7 +2037,7 @@ def _load_screen_css(self, screen: Screen):
self.stylesheet.reparse()
self.stylesheet.update(self)

def _replace_screen(self, screen: Screen) -> Screen:
async def _replace_screen(self, screen: Screen) -> Screen:
"""Handle the replaced screen.
Args:
Expand All @@ -2046,7 +2053,7 @@ def _replace_screen(self, screen: Screen) -> Screen:
if not self.is_screen_installed(screen) and all(
screen not in stack for stack in self._screen_stacks.values()
):
screen.remove()
await screen.remove()
self.log.system(f"{screen} REMOVED")
return screen

Expand Down Expand Up @@ -2151,9 +2158,10 @@ async def push_screen_wait(
Returns:
The screen's result.
"""
await self._flush_next_callbacks()
return await self.push_screen(screen, wait_for_dismiss=True)

def switch_screen(self, screen: Screen | str) -> AwaitMount:
def switch_screen(self, screen: Screen | str) -> AwaitComplete:
"""Switch to another [screen](/guide/screens) by replacing the top of the screen stack with a new screen.
Args:
Expand All @@ -2164,19 +2172,23 @@ def switch_screen(self, screen: Screen | str) -> AwaitMount:
f"switch_screen requires a Screen instance or str; not {screen!r}"
)

next_screen, await_mount = self._get_screen(screen)
if screen is self.screen or next_screen is self.screen:
self.log.system(f"Screen {screen} is already current.")
return AwaitMount(self.screen, [])
async def do_switch() -> None:
next_screen, await_mount = self._get_screen(screen)
if screen is self.screen or next_screen is self.screen:
self.log.system(f"Screen {screen} is already current.")
return

previous_screen = self._replace_screen(self._screen_stack.pop())
previous_screen._pop_result_callback()
self._load_screen_css(next_screen)
self._screen_stack.append(next_screen)
self.screen.post_message(events.ScreenResume())
self.screen._push_result_callback(self.screen, None)
self.log.system(f"{self.screen} is current (SWITCHED)")
return await_mount
await await_mount()

previous_screen = await self._replace_screen(self._screen_stack.pop())
previous_screen._pop_result_callback()
self._load_screen_css(next_screen)
self._screen_stack.append(next_screen)
self.screen.post_message(events.ScreenResume())
self.screen._push_result_callback(self.screen, None)
self.log.system(f"{self.screen} is current (SWITCHED)")

return AwaitComplete(do_switch()).call_next(self)

def install_screen(self, screen: Screen, name: str) -> None:
"""Install a screen.
Expand Down Expand Up @@ -2238,22 +2250,26 @@ def uninstall_screen(self, screen: Screen | str) -> str | None:
return name
return None

def pop_screen(self) -> Screen[object]:
def pop_screen(self) -> AwaitComplete:
"""Pop the current [screen](/guide/screens) from the stack, and switch to the previous screen.
Returns:
The screen that was replaced.
"""

screen_stack = self._screen_stack
if len(screen_stack) <= 1:
raise ScreenStackError(
"Can't pop screen; there must be at least one screen on the stack"
)
previous_screen = self._replace_screen(screen_stack.pop())
previous_screen._pop_result_callback()
self.screen.post_message(events.ScreenResume())
self.log.system(f"{self.screen} is active")
return previous_screen

async def do_pop() -> None:
previous_screen = await self._replace_screen(screen_stack.pop())
previous_screen._pop_result_callback()
self.screen.post_message(events.ScreenResume())
self.log.system(f"{self.screen} is active")

return AwaitComplete(do_pop()).call_next(self)

def set_focus(self, widget: Widget | None, scroll_visible: bool = True) -> None:
"""Focus (or unfocus) a widget. A focused widget will receive key events first.
Expand Down Expand Up @@ -2354,6 +2370,7 @@ def _handle_exception(self, error: Exception) -> None:
Args:
error: An exception instance.
"""
self.log.error(error)
self._return_code = 1
# If we're running via pilot and this is the first exception encountered,
# take note of it so that we can re-raise for test frameworks later.
Expand Down Expand Up @@ -3408,7 +3425,7 @@ async def _prune_nodes(self, widgets: list[Widget]) -> None:
"""
async with self._dom_lock:
for widget in widgets:
await self._prune_node(widget)
await asyncio.shield(self._prune_node(widget))

async def _prune_node(self, root: Widget) -> None:
"""Remove a node and its children. Children are removed before parents.
Expand Down
12 changes: 12 additions & 0 deletions src/textual/await_complete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import Any, Awaitable, Generator

import rich.repr
from typing_extensions import Self

from .message_pump import MessagePump


@rich.repr.auto(angular=True)
Expand All @@ -18,6 +21,15 @@ def __init__(self, *awaitables: Awaitable) -> None:
"""
self._future: Future[Any] = gather(*awaitables)

def call_next(self, node: MessagePump) -> Self:
"""Await after the next message.
Args:
node: The node which created the object.
"""
node.call_next(self)
return self

async def __call__(self) -> Any:
return await self

Expand Down
5 changes: 4 additions & 1 deletion src/textual/message_pump.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,10 @@ async def _close_messages(self, wait: bool = True) -> None:
running_widget = None

if running_widget is None or running_widget is not self:
await self._task
try:
await self._task
except CancelledError:
pass

def _start_messages(self) -> None:
"""Start messages task."""
Expand Down
22 changes: 13 additions & 9 deletions src/textual/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._context import active_message_pump, visible_screen_stack
from ._path import CSSPathType, _css_path_type_as_list, _make_path_object_relative
from ._types import CallbackType
from .await_complete import AwaitComplete
from .binding import ActiveBinding, Binding, _Bindings
from .css.match import match
from .css.parse import parse_selectors
Expand Down Expand Up @@ -1226,43 +1227,46 @@ def _forward_event(self, event: events.Event) -> None:
class _NoResult:
"""Class used to mark that there is no result."""

def dismiss(self, result: ScreenResultType | Type[_NoResult] = _NoResult) -> bool:
def dismiss(
self, result: ScreenResultType | Type[_NoResult] = _NoResult
) -> AwaitComplete:
"""Dismiss the screen, optionally with a result.
!!! note
Only the active screen may be dismissed. If you try to dismiss a screen that isn't active,
this method will return `False`.
this method will raise a `ScreenError`.
If `result` is provided and a callback was set when the screen was [pushed][textual.app.App.push_screen], then
the callback will be invoked with `result`.
Args:
result: The optional result to be passed to the result callback.
Returns:
`True` if the Screen was dismissed, or `False` if the Screen wasn't dismissed due to not being active.
Raises:
ScreenError: If the screen being dismissed is not active.
ScreenStackError: If trying to dismiss a screen that is not at the top of
the stack.
"""
if not self.is_active:
return False
from .app import ScreenError

raise ScreenError("Screen is not active")
if result is not self._NoResult and self._result_callbacks:
self._result_callbacks[-1](cast(ScreenResultType, result))
self.app.pop_screen()
return True
await_pop = self.app.pop_screen()
return await_pop

def action_dismiss(
async def action_dismiss(
self, result: ScreenResultType | Type[_NoResult] = _NoResult
) -> None:
"""A wrapper around [`dismiss`][textual.screen.Screen.dismiss] that can be called as an action.
Args:
result: The optional result to be passed to the result callback.
"""
await self._flush_next_callbacks()
self.dismiss(result)

def can_view(self, widget: Widget) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def set_loading(self, loading: bool) -> Awaitable:
loading_indicator = self.get_loading_widget()
loading_indicator.add_class(LOADING_INDICATOR_CLASS)
await_mount = self.mount(loading_indicator)
return AwaitComplete(remove_indicator, await_mount)
return AwaitComplete(remove_indicator, await_mount).call_next(self)
else:
return remove_indicator

Expand Down
2 changes: 1 addition & 1 deletion src/textual/widgets/_directory_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _add_to_load_queue(self, node: TreeNode[DirEntry]) -> AwaitComplete:
node.data.loaded = True
self._load_queue.put_nowait(node)

return AwaitComplete(self._load_queue.join())
return AwaitComplete(self, self._load_queue.join())

def reload(self) -> AwaitComplete:
"""Reload the `DirectoryTree` contents.
Expand Down

0 comments on commit ee71b34

Please sign in to comment.