Skip to content

Commit

Permalink
Merge pull request #4894 from ZeroIntensity/disallow-screen-instances
Browse files Browse the repository at this point in the history
Disallow `Screen` instances in `App.SCREENS`
  • Loading branch information
willmcgugan authored Aug 22, 2024
2 parents 68ba2e9 + f739b77 commit 5573697
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

- Disallowed `Screen` instances in `App.SCREENS` and `App.MODES`
- Fixed `App.MODES` being the same for all instances -- per-instance modes now exist internally


### Added

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/guide/screens/screen01.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def compose(self) -> ComposeResult:

class BSODApp(App):
CSS_PATH = "screen01.tcss"
SCREENS = {"bsod": BSOD()}
SCREENS = {"bsod": BSOD}
BINDINGS = [("b", "push_screen('bsod')", "BSOD")]


Expand Down
55 changes: 39 additions & 16 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class App(Generic[ReturnType], DOMNode):
}
"""

MODES: ClassVar[dict[str, str | Screen | Callable[[], Screen]]] = {}
MODES: ClassVar[dict[str, str | Callable[[], Screen]]] = {}
"""Modes associated with the app and their base screens.
The base screen is the screen at the bottom of the mode stack. You can think of
Expand Down Expand Up @@ -324,7 +324,7 @@ class MyApp(App[None]):
...
```
"""
SCREENS: ClassVar[dict[str, Screen[Any] | Callable[[], Screen[Any]]]] = {}
SCREENS: ClassVar[dict[str, Callable[[], Screen[Any]]]] = {}
"""Screens associated with the app for the lifetime of the app."""

AUTO_FOCUS: ClassVar[str | None] = "*"
Expand Down Expand Up @@ -560,6 +560,8 @@ def __init__(

self._installed_screens: dict[str, Screen | Callable[[], Screen]] = {}
self._installed_screens.update(**self.SCREENS)
self._modes: dict[str, str | Callable[[], Screen]] = self.MODES.copy()
"""Contains the working-copy of the `MODES` for each instance."""

self._compose_stacks: list[list[Widget]] = []
self._composed: list[list[Widget]] = []
Expand Down Expand Up @@ -664,6 +666,24 @@ def __init__(
)
)

def __init_subclass__(cls, *args, **kwargs) -> None:
for variable_name, screen_collection in (
("SCREENS", cls.SCREENS),
("MODES", cls.MODES),
):
for screen_name, screen_object in screen_collection.items():
if not (isinstance(screen_object, str) or callable(screen_object)):
if isinstance(screen_object, Screen):
raise ValueError(
f"{variable_name} should contain a Screen type or callable, not an instance"
f" (got instance of {type(screen_object).__name__} for {screen_name!r})"
)
raise TypeError(
f"expected a callable or string, got {screen_object!r}"
)

return super().__init_subclass__(*args, **kwargs)

def validate_title(self, title: Any) -> str:
"""Make sure the title is set to a string."""
return str(title)
Expand Down Expand Up @@ -1901,7 +1921,12 @@ def _init_mode(self, mode: str) -> AwaitMount:
if stack:
await_mount = AwaitMount(stack[0], [])
else:
_screen = self.MODES[mode]
_screen = self._modes[mode]
if isinstance(_screen, Screen):
raise TypeError(
"MODES cannot contain instances, use a type instead "
f"(got instance of {type(_screen).__name__} for {mode!r})"
)
new_screen: Screen | str = _screen() if callable(_screen) else _screen
screen, await_mount = self._get_screen(new_screen)
stack.append(screen)
Expand All @@ -1923,7 +1948,7 @@ def switch_mode(self, mode: str) -> AwaitMount:
Raises:
UnknownModeError: If trying to switch to an unknown mode.
"""
if mode not in self.MODES:
if mode not in self._modes:
raise UnknownModeError(f"No known mode {mode!r}")

self.screen.post_message(events.ScreenSuspend())
Expand All @@ -1943,9 +1968,7 @@ def switch_mode(self, mode: str) -> AwaitMount:

return await_mount

def add_mode(
self, mode: str, base_screen: str | Screen | Callable[[], Screen]
) -> None:
def add_mode(self, mode: str, base_screen: str | Callable[[], Screen]) -> None:
"""Adds a mode and its corresponding base screen to the app.
Args:
Expand All @@ -1957,10 +1980,15 @@ def add_mode(
"""
if mode == "_default":
raise InvalidModeError("Cannot use '_default' as a custom mode.")
elif mode in self.MODES:
elif mode in self._modes:
raise InvalidModeError(f"Duplicated mode name {mode!r}.")

self.MODES[mode] = base_screen
if isinstance(base_screen, Screen):
raise TypeError(
"add_mode() must be called with a Screen type, not an instance"
f" (got instance of {type(base_screen).__name__})"
)
self._modes[mode] = base_screen

def remove_mode(self, mode: str) -> AwaitComplete:
"""Removes a mode from the app.
Expand All @@ -1976,10 +2004,10 @@ def remove_mode(self, mode: str) -> AwaitComplete:
"""
if mode == self._current_mode:
raise ActiveModeError(f"Can't remove active mode {mode!r}")
elif mode not in self.MODES:
elif mode not in self._modes:
raise UnknownModeError(f"Unknown mode {mode!r}")
else:
del self.MODES[mode]
del self._modes[mode]

if mode not in self._screen_stacks:
return AwaitComplete.nothing()
Expand Down Expand Up @@ -2860,11 +2888,6 @@ async def _close_all(self) -> None:
await self._prune(stack_screen)
stack.clear()

# Close pre-defined screens.
for screen in self.SCREENS.values():
if isinstance(screen, Screen) and screen._running:
await self._prune(screen)

# Close any remaining nodes
# Should be empty by now
remaining_nodes = list(self._registry)
Expand Down
14 changes: 7 additions & 7 deletions tests/css/test_screen_css.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def test_screen_css_push_screen_instance_by_name():
"""Check that screen CSS is loaded and applied when pushing a screen name that points to a screen instance."""

class MyApp(BaseApp):
SCREENS = {"screenwithcss": ScreenWithCSS()}
SCREENS = {"screenwithcss": ScreenWithCSS}

def key_p(self):
self.push_screen("screenwithcss")
Expand Down Expand Up @@ -187,7 +187,7 @@ async def test_screen_css_switch_screen_instance_by_name():
"""Check that screen CSS is loaded and applied when switching a screen name that points to a screen instance."""

class MyApp(SwitchBaseApp):
SCREENS = {"screenwithcss": ScreenWithCSS()}
SCREENS = {"screenwithcss": ScreenWithCSS}

def key_p(self):
self.switch_screen("screenwithcss")
Expand Down Expand Up @@ -230,8 +230,8 @@ async def test_screen_css_switch_mode_screen_instance():

class MyApp(BaseApp):
MODES = {
"base": BaseScreen(),
"mode": ScreenWithCSS(),
"base": BaseScreen,
"mode": ScreenWithCSS,
}

def key_p(self):
Expand All @@ -255,11 +255,11 @@ async def test_screen_css_switch_mode_screen_instance_by_name():

class MyApp(BaseApp):
SCREENS = {
"screenwithcss": ScreenWithCSS(),
"screenwithcss": ScreenWithCSS,
}

MODES = {
"base": BaseScreen(),
"base": BaseScreen,
"mode": "screenwithcss",
}

Expand Down Expand Up @@ -288,7 +288,7 @@ class MyApp(BaseApp):
}

MODES = {
"base": BaseScreen(),
"base": BaseScreen,
"mode": "screenwithcss",
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def compose(self) -> ComposeResult:
class NotifyThroughModesApp(App[None]):

MODES = {
"test": Mode()
"test": Mode
}

def compose(self) -> ComposeResult:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_screen_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def test_remove_mode(ModesApp: Type[App]):
await pilot.pause()
assert str(app.screen.query_one(Label).renderable) == "two"
app.remove_mode("one")
assert "one" not in app.MODES
assert "one" not in app._modes


async def test_remove_active_mode(ModesApp: Type[App]):
Expand All @@ -130,7 +130,7 @@ async def test_remove_active_mode(ModesApp: Type[App]):
async def test_add_mode(ModesApp: Type[App]):
app = ModesApp()
async with app.run_test() as pilot:
app.add_mode("three", BaseScreen("three"))
app.add_mode("three", lambda: BaseScreen("three"))
await app.switch_mode("three")
await pilot.pause()
assert str(app.screen.query_one(Label).renderable) == "three"
Expand All @@ -140,7 +140,7 @@ async def test_add_mode_duplicated(ModesApp: Type[App]):
app = ModesApp()
async with app.run_test():
with pytest.raises(InvalidModeError):
app.add_mode("one", BaseScreen("one"))
app.add_mode("one", lambda: BaseScreen("one"))


async def test_screen_stack_preserved(ModesApp: Type[App]):
Expand Down
41 changes: 35 additions & 6 deletions tests/test_screens.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def test_installed_screens():
class ScreensApp(App):
SCREENS = {
"home": Screen, # Screen type
"one": Screen(), # Screen instance
"one": Screen, # Screen instance, disallowed as of #4893
"two": Screen, # Callable[[], Screen]
}

Expand Down Expand Up @@ -354,7 +354,7 @@ class MyScreen(Screen):
pass

class MyApp(App[None]):
SCREENS = {"screen": MyScreen()}
SCREENS = {"screen": MyScreen}

def on_mount(self):
self.push_screen("screen")
Expand All @@ -379,8 +379,8 @@ class ScreenB(Screen):

class MyApp(App[None]):
SCREENS = {
"a": ScreenA(),
"b": ScreenB(),
"a": ScreenA,
"b": ScreenB,
}

def callback(self, _):
Expand All @@ -407,7 +407,7 @@ def on_mouse_move(self, event: MouseMove) -> None:
MouseMoveRecordingScreen.mouse_events.append(event)

class SimpleApp(App[None]):
SCREENS = {"a": MouseMoveRecordingScreen()}
SCREENS = {"a": MouseMoveRecordingScreen}

def on_mount(self):
self.push_screen("a")
Expand Down Expand Up @@ -439,7 +439,7 @@ def on_mouse_move(self, event: MouseMove) -> None:
MouseMoveRecordingScreen.mouse_events.append(event)

class SimpleApp(App[None]):
SCREENS = {"a": MouseMoveRecordingScreen()}
SCREENS = {"a": MouseMoveRecordingScreen}

def on_mount(self):
self.push_screen("a")
Expand Down Expand Up @@ -539,6 +539,35 @@ def get_default_screen(self) -> Screen:
assert app.screen is app.screen_stack[0]


async def test_disallow_screen_instances() -> None:
"""Test that screen instances are disallowed."""

class CustomScreen(Screen):
pass

with pytest.raises(ValueError):

class Bad(App):
SCREENS = {"a": CustomScreen()} # type: ignore

with pytest.raises(ValueError):

class Worse(App):
MODES = {"a": CustomScreen()} # type: ignore

# While we're here, let's make sure that other types
# are disallowed.
with pytest.raises(TypeError):

class Terrible(App):
MODES = {"a": 42, "b": CustomScreen} # type: ignore

with pytest.raises(TypeError):

class Worst(App):
MODES = {"OK": CustomScreen, 1: 2} # type: ignore


async def test_worker_cancellation():
"""Regression test for https://github.com/Textualize/textual/issues/4884
Expand Down

0 comments on commit 5573697

Please sign in to comment.