diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ea23b3031..ebdd65ff34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +- Disallowed `Screen` instances in `App.SCREENS` and `App.MODES` +- Fixed `App.MODES` being the same for all instances -- per-instance modes now exist internally + ### Added diff --git a/docs/examples/guide/screens/screen01.py b/docs/examples/guide/screens/screen01.py index 568c98d97a..91592e3692 100644 --- a/docs/examples/guide/screens/screen01.py +++ b/docs/examples/guide/screens/screen01.py @@ -25,7 +25,7 @@ def compose(self) -> ComposeResult: class BSODApp(App): CSS_PATH = "screen01.tcss" - SCREENS = {"bsod": BSOD()} + SCREENS = {"bsod": BSOD} BINDINGS = [("b", "push_screen('bsod')", "BSOD")] diff --git a/src/textual/app.py b/src/textual/app.py index c12ba3cdaa..1f1a32b824 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -295,7 +295,7 @@ class App(Generic[ReturnType], DOMNode): } """ - MODES: ClassVar[dict[str, str | Screen | Callable[[], Screen]]] = {} + MODES: ClassVar[dict[str, str | Callable[[], Screen]]] = {} """Modes associated with the app and their base screens. The base screen is the screen at the bottom of the mode stack. You can think of @@ -324,7 +324,7 @@ class MyApp(App[None]): ... ``` """ - SCREENS: ClassVar[dict[str, Screen[Any] | Callable[[], Screen[Any]]]] = {} + SCREENS: ClassVar[dict[str, Callable[[], Screen[Any]]]] = {} """Screens associated with the app for the lifetime of the app.""" AUTO_FOCUS: ClassVar[str | None] = "*" @@ -560,6 +560,8 @@ def __init__( self._installed_screens: dict[str, Screen | Callable[[], Screen]] = {} self._installed_screens.update(**self.SCREENS) + self._modes: dict[str, str | Callable[[], Screen]] = self.MODES.copy() + """Contains the working-copy of the `MODES` for each instance.""" self._compose_stacks: list[list[Widget]] = [] self._composed: list[list[Widget]] = [] @@ -664,6 +666,24 @@ def __init__( ) ) + def __init_subclass__(cls, *args, **kwargs) -> None: + for variable_name, screen_collection in ( + ("SCREENS", cls.SCREENS), + ("MODES", cls.MODES), + ): + for screen_name, screen_object in screen_collection.items(): + if not (isinstance(screen_object, str) or callable(screen_object)): + if isinstance(screen_object, Screen): + raise ValueError( + f"{variable_name} should contain a Screen type or callable, not an instance" + f" (got instance of {type(screen_object).__name__} for {screen_name!r})" + ) + raise TypeError( + f"expected a callable or string, got {screen_object!r}" + ) + + return super().__init_subclass__(*args, **kwargs) + def validate_title(self, title: Any) -> str: """Make sure the title is set to a string.""" return str(title) @@ -1901,7 +1921,12 @@ def _init_mode(self, mode: str) -> AwaitMount: if stack: await_mount = AwaitMount(stack[0], []) else: - _screen = self.MODES[mode] + _screen = self._modes[mode] + if isinstance(_screen, Screen): + raise TypeError( + "MODES cannot contain instances, use a type instead " + f"(got instance of {type(_screen).__name__} for {mode!r})" + ) new_screen: Screen | str = _screen() if callable(_screen) else _screen screen, await_mount = self._get_screen(new_screen) stack.append(screen) @@ -1923,7 +1948,7 @@ def switch_mode(self, mode: str) -> AwaitMount: Raises: UnknownModeError: If trying to switch to an unknown mode. """ - if mode not in self.MODES: + if mode not in self._modes: raise UnknownModeError(f"No known mode {mode!r}") self.screen.post_message(events.ScreenSuspend()) @@ -1943,9 +1968,7 @@ def switch_mode(self, mode: str) -> AwaitMount: return await_mount - def add_mode( - self, mode: str, base_screen: str | Screen | Callable[[], Screen] - ) -> None: + def add_mode(self, mode: str, base_screen: str | Callable[[], Screen]) -> None: """Adds a mode and its corresponding base screen to the app. Args: @@ -1957,10 +1980,15 @@ def add_mode( """ if mode == "_default": raise InvalidModeError("Cannot use '_default' as a custom mode.") - elif mode in self.MODES: + elif mode in self._modes: raise InvalidModeError(f"Duplicated mode name {mode!r}.") - self.MODES[mode] = base_screen + if isinstance(base_screen, Screen): + raise TypeError( + "add_mode() must be called with a Screen type, not an instance" + f" (got instance of {type(base_screen).__name__})" + ) + self._modes[mode] = base_screen def remove_mode(self, mode: str) -> AwaitComplete: """Removes a mode from the app. @@ -1976,10 +2004,10 @@ def remove_mode(self, mode: str) -> AwaitComplete: """ if mode == self._current_mode: raise ActiveModeError(f"Can't remove active mode {mode!r}") - elif mode not in self.MODES: + elif mode not in self._modes: raise UnknownModeError(f"Unknown mode {mode!r}") else: - del self.MODES[mode] + del self._modes[mode] if mode not in self._screen_stacks: return AwaitComplete.nothing() @@ -2860,11 +2888,6 @@ async def _close_all(self) -> None: await self._prune(stack_screen) stack.clear() - # Close pre-defined screens. - for screen in self.SCREENS.values(): - if isinstance(screen, Screen) and screen._running: - await self._prune(screen) - # Close any remaining nodes # Should be empty by now remaining_nodes = list(self._registry) diff --git a/tests/css/test_screen_css.py b/tests/css/test_screen_css.py index 61d26a6a38..9e6668f6ac 100644 --- a/tests/css/test_screen_css.py +++ b/tests/css/test_screen_css.py @@ -126,7 +126,7 @@ async def test_screen_css_push_screen_instance_by_name(): """Check that screen CSS is loaded and applied when pushing a screen name that points to a screen instance.""" class MyApp(BaseApp): - SCREENS = {"screenwithcss": ScreenWithCSS()} + SCREENS = {"screenwithcss": ScreenWithCSS} def key_p(self): self.push_screen("screenwithcss") @@ -187,7 +187,7 @@ async def test_screen_css_switch_screen_instance_by_name(): """Check that screen CSS is loaded and applied when switching a screen name that points to a screen instance.""" class MyApp(SwitchBaseApp): - SCREENS = {"screenwithcss": ScreenWithCSS()} + SCREENS = {"screenwithcss": ScreenWithCSS} def key_p(self): self.switch_screen("screenwithcss") @@ -230,8 +230,8 @@ async def test_screen_css_switch_mode_screen_instance(): class MyApp(BaseApp): MODES = { - "base": BaseScreen(), - "mode": ScreenWithCSS(), + "base": BaseScreen, + "mode": ScreenWithCSS, } def key_p(self): @@ -255,11 +255,11 @@ async def test_screen_css_switch_mode_screen_instance_by_name(): class MyApp(BaseApp): SCREENS = { - "screenwithcss": ScreenWithCSS(), + "screenwithcss": ScreenWithCSS, } MODES = { - "base": BaseScreen(), + "base": BaseScreen, "mode": "screenwithcss", } @@ -288,7 +288,7 @@ class MyApp(BaseApp): } MODES = { - "base": BaseScreen(), + "base": BaseScreen, "mode": "screenwithcss", } diff --git a/tests/snapshot_tests/snapshot_apps/notification_through_modes.py b/tests/snapshot_tests/snapshot_apps/notification_through_modes.py index 5c0e0ee3e8..e739de44a1 100644 --- a/tests/snapshot_tests/snapshot_apps/notification_through_modes.py +++ b/tests/snapshot_tests/snapshot_apps/notification_through_modes.py @@ -11,7 +11,7 @@ def compose(self) -> ComposeResult: class NotifyThroughModesApp(App[None]): MODES = { - "test": Mode() + "test": Mode } def compose(self) -> ComposeResult: diff --git a/tests/test_screen_modes.py b/tests/test_screen_modes.py index 85d1a8b358..35576c3915 100644 --- a/tests/test_screen_modes.py +++ b/tests/test_screen_modes.py @@ -117,7 +117,7 @@ async def test_remove_mode(ModesApp: Type[App]): await pilot.pause() assert str(app.screen.query_one(Label).renderable) == "two" app.remove_mode("one") - assert "one" not in app.MODES + assert "one" not in app._modes async def test_remove_active_mode(ModesApp: Type[App]): @@ -130,7 +130,7 @@ async def test_remove_active_mode(ModesApp: Type[App]): async def test_add_mode(ModesApp: Type[App]): app = ModesApp() async with app.run_test() as pilot: - app.add_mode("three", BaseScreen("three")) + app.add_mode("three", lambda: BaseScreen("three")) await app.switch_mode("three") await pilot.pause() assert str(app.screen.query_one(Label).renderable) == "three" @@ -140,7 +140,7 @@ async def test_add_mode_duplicated(ModesApp: Type[App]): app = ModesApp() async with app.run_test(): with pytest.raises(InvalidModeError): - app.add_mode("one", BaseScreen("one")) + app.add_mode("one", lambda: BaseScreen("one")) async def test_screen_stack_preserved(ModesApp: Type[App]): diff --git a/tests/test_screens.py b/tests/test_screens.py index 044ca86555..00f3504361 100644 --- a/tests/test_screens.py +++ b/tests/test_screens.py @@ -40,7 +40,7 @@ async def test_installed_screens(): class ScreensApp(App): SCREENS = { "home": Screen, # Screen type - "one": Screen(), # Screen instance + "one": Screen, # Screen instance, disallowed as of #4893 "two": Screen, # Callable[[], Screen] } @@ -354,7 +354,7 @@ class MyScreen(Screen): pass class MyApp(App[None]): - SCREENS = {"screen": MyScreen()} + SCREENS = {"screen": MyScreen} def on_mount(self): self.push_screen("screen") @@ -379,8 +379,8 @@ class ScreenB(Screen): class MyApp(App[None]): SCREENS = { - "a": ScreenA(), - "b": ScreenB(), + "a": ScreenA, + "b": ScreenB, } def callback(self, _): @@ -407,7 +407,7 @@ def on_mouse_move(self, event: MouseMove) -> None: MouseMoveRecordingScreen.mouse_events.append(event) class SimpleApp(App[None]): - SCREENS = {"a": MouseMoveRecordingScreen()} + SCREENS = {"a": MouseMoveRecordingScreen} def on_mount(self): self.push_screen("a") @@ -439,7 +439,7 @@ def on_mouse_move(self, event: MouseMove) -> None: MouseMoveRecordingScreen.mouse_events.append(event) class SimpleApp(App[None]): - SCREENS = {"a": MouseMoveRecordingScreen()} + SCREENS = {"a": MouseMoveRecordingScreen} def on_mount(self): self.push_screen("a") @@ -539,6 +539,35 @@ def get_default_screen(self) -> Screen: assert app.screen is app.screen_stack[0] +async def test_disallow_screen_instances() -> None: + """Test that screen instances are disallowed.""" + + class CustomScreen(Screen): + pass + + with pytest.raises(ValueError): + + class Bad(App): + SCREENS = {"a": CustomScreen()} # type: ignore + + with pytest.raises(ValueError): + + class Worse(App): + MODES = {"a": CustomScreen()} # type: ignore + + # While we're here, let's make sure that other types + # are disallowed. + with pytest.raises(TypeError): + + class Terrible(App): + MODES = {"a": 42, "b": CustomScreen} # type: ignore + + with pytest.raises(TypeError): + + class Worst(App): + MODES = {"OK": CustomScreen, 1: 2} # type: ignore + + async def test_worker_cancellation(): """Regression test for https://github.com/Textualize/textual/issues/4884