Skip to content

Commit

Permalink
Merge pull request #5025 from Textualize/expected-screen
Browse files Browse the repository at this point in the history
overload get_screen
  • Loading branch information
willmcgugan authored Sep 20, 2024
2 parents cbef6cd + 52fff84 commit 14794db
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
33 changes: 32 additions & 1 deletion src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@
CommandCallback: TypeAlias = "Callable[[], Awaitable[Any]] | Callable[[], Any]"
"""Signature for callbacks used in [`get_system_commands`][textual.app.App.get_system_commands]"""

ScreenType = TypeVar("ScreenType", bound=Screen)
"""Type var for a Screen, used in [`get_screen`][textual.app.App.get_screen]."""


class SystemCommand(NamedTuple):
"""Defines a system command used in the command palette (yielded from [`get_system_commands`][textual.app.App.get_system_commands])."""
Expand Down Expand Up @@ -2263,11 +2266,35 @@ def is_screen_installed(self, screen: Screen | str) -> bool:
else:
return screen in self._installed_screens.values()

def get_screen(self, screen: Screen | str) -> Screen:
@overload
def get_screen(self, screen: ScreenType) -> ScreenType: ...

@overload
def get_screen(self, screen: str) -> Screen: ...

@overload
def get_screen(
self, screen: str, screen_class: Type[ScreenType] | None = None
) -> ScreenType: ...

@overload
def get_screen(
self, screen: ScreenType, screen_class: Type[ScreenType] | None = None
) -> ScreenType: ...

def get_screen(
self, screen: Screen | str, screen_class: Type[Screen] | None = None
) -> Screen:
"""Get an installed screen.
Example:
```python
my_screen = self.get_screen("settings", MyScreen)
```
Args:
screen: Either a Screen object or screen name (the `name` argument when installed).
screen_class: Class of expected screen, or `None` for any screen class.
Raises:
KeyError: If the named screen doesn't exist.
Expand All @@ -2285,6 +2312,10 @@ def get_screen(self, screen: Screen | str) -> Screen:
self._installed_screens[screen] = next_screen
else:
next_screen = screen
if screen_class is not None and not isinstance(next_screen, screen_class):
raise TypeError(
f"Expected a screen of type {screen_class}, got {type(next_screen)}"
)
return next_screen

def _get_screen(self, screen: Screen | str) -> tuple[Screen, AwaitMount]:
Expand Down
3 changes: 3 additions & 0 deletions src/textual/widgets/directory_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from textual.widgets._directory_tree import DirEntry

__all__ = ["DirEntry"]
32 changes: 32 additions & 0 deletions tests/test_screens.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,35 @@ async def action_info(self) -> None:
# Press enter to activate button to dismiss them
await pilot.press("enter")
await pilot.press("enter")


async def test_get_screen_with_expected_type():
"""Test get_screen with expected type works"""

class BadScreen(Screen[None]):
pass

class MyScreen(Screen[None]):
def compose(self):
yield Label()
yield Button()

class MyApp(App[None]):
SCREENS = {"my_screen": MyScreen}

def on_mount(self):
self.push_screen("my_screen")

app = MyApp()
async with app.run_test():
screen = app.get_screen("my_screen")
# Should be fine
assert isinstance(screen, MyScreen)

screen = app.get_screen("my_screen", MyScreen)
# Should be fine
assert isinstance(screen, MyScreen)

# TypeError because my_screen is not a BadScreen
with pytest.raises(TypeError):
screen = app.get_screen("my_screen", BadScreen)

0 comments on commit 14794db

Please sign in to comment.