diff --git a/src/textual/_context.py b/src/textual/_context.py index ca94f9c868..d28a1b15a2 100644 --- a/src/textual/_context.py +++ b/src/textual/_context.py @@ -1,8 +1,7 @@ from __future__ import annotations -import weakref -from contextvars import ContextVar, Token -from typing import TYPE_CHECKING, Callable, Generic, TypeVar, overload +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, Callable, TypeVar if TYPE_CHECKING: from .app import App @@ -19,54 +18,8 @@ class NoActiveAppError(RuntimeError): DefaultType = TypeVar("DefaultType") -class ContextDefault: - pass - - -_context_default = ContextDefault() - - -class TextualContextVar(Generic[ContextVarType]): - """Like ContextVar but doesn't hold on to references.""" - - def __init__(self, name: str) -> None: - self._context_var: ContextVar[weakref.ReferenceType[ContextVarType]] = ( - ContextVar(name) - ) - - @overload - def get(self) -> ContextVarType: ... - - @overload - def get(self, default: DefaultType) -> ContextVarType | DefaultType: ... - - def get( - self, default: DefaultType | ContextDefault = _context_default - ) -> ContextVarType | DefaultType: - try: - value_ref = self._context_var.get() - except LookupError: - if isinstance(default, ContextDefault): - raise - return default - value = value_ref() - if value is None: - if isinstance(default, ContextDefault): - raise LookupError(value) - return default - return value - - def set(self, value: ContextVarType) -> object: - return self._context_var.set(weakref.ref(value)) - - def reset(self, token: Token[weakref.ReferenceType[ContextVarType]]) -> None: - self._context_var.reset(token) - - -active_app: TextualContextVar["App[object]"] = TextualContextVar("active_app") -active_message_pump: TextualContextVar["MessagePump"] = TextualContextVar( - "active_message_pump" -) +active_app: ContextVar["App[Any]"] = ContextVar("active_app") +active_message_pump: ContextVar["MessagePump"] = ContextVar("active_message_pump") prevent_message_types_stack: ContextVar[list[set[type[Message]]]] = ContextVar( "prevent_message_types_stack" @@ -75,7 +28,5 @@ def reset(self, token: Token[weakref.ReferenceType[ContextVarType]]) -> None: "visible_screen_stack" ) """A stack of visible screens (with background alpha < 1), used in the screen render process.""" -message_hook: TextualContextVar[Callable[[Message], None]] = TextualContextVar( - "message_hook" -) +message_hook: ContextVar[Callable[[Message], None]] = ContextVar("message_hook") """A callable that accepts a message. Used by App.run_test.""" diff --git a/src/textual/app.py b/src/textual/app.py index 759fac5065..ba0e79bb9f 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -809,6 +809,17 @@ def _end_batch(self) -> None: if not self._batch_count: self.check_idle() + @contextmanager + def _context(self) -> Generator[None, None, None]: + """Context manager to set ContextVars.""" + app_reset_token = active_app.set(self) + message_pump_reset_token = active_message_pump.set(self) + try: + yield + finally: + active_message_pump.reset(message_pump_reset_token) + active_app.reset(app_reset_token) + def animate( self, attribute: str, @@ -1046,10 +1057,6 @@ def get_default_screen(self) -> Screen: """ return Screen(id="_default") - def _set_active(self) -> None: - """Set this app to be the currently active app.""" - active_app.set(self) - def compose(self) -> ComposeResult: """Yield child widgets for a container. @@ -1355,8 +1362,8 @@ def call_from_thread( async def run_callback() -> CallThreadReturnType: """Run the callback, set the result or error on the future.""" - self._set_active() - return await invoke(callback_with_args) + with self._context(): + return await invoke(callback_with_args) # Post the message to the main loop future: Future[CallThreadReturnType] = asyncio.run_coroutine_threadsafe( @@ -1667,41 +1674,39 @@ async def run_app(app: App) -> None: app: App to run. """ - try: - if message_hook is not None: - message_hook_context_var.set(message_hook) - app._loop = asyncio.get_running_loop() - app._thread_id = threading.get_ident() - await app._process_messages( - ready_callback=on_app_ready, - headless=headless, - terminal_size=size, - ) - finally: - app_ready_event.set() + with app._context(): + try: + if message_hook is not None: + message_hook_context_var.set(message_hook) + app._loop = asyncio.get_running_loop() + app._thread_id = threading.get_ident() + await app._process_messages( + ready_callback=on_app_ready, + headless=headless, + terminal_size=size, + ) + finally: + app_ready_event.set() # Launch the app in the "background" - active_message_pump.set(app) + app_task = create_task(run_app(app), name=f"run_test {app}") # Wait until the app has performed all startup routines. await app_ready_event.wait() - - # Get the app in an active state. - app._set_active() - - # Context manager returns pilot object to manipulate the app - try: - pilot = Pilot(app) - await pilot._wait_for_screen() - yield pilot - finally: - # Shutdown the app cleanly - await app._shutdown() - await app_task - # Re-raise the exception which caused panic so test frameworks are aware - if self._exception: - raise self._exception + with app._context(): + # Context manager returns pilot object to manipulate the app + try: + pilot = Pilot(app) + await pilot._wait_for_screen() + yield pilot + finally: + # Shutdown the app cleanly + await app._shutdown() + await app_task + # Re-raise the exception which caused panic so test frameworks are aware + if self._exception: + raise self._exception async def run_async( self, @@ -1751,14 +1756,14 @@ async def app_ready() -> None: async def run_auto_pilot( auto_pilot: AutopilotCallbackType, pilot: Pilot ) -> None: - try: - await auto_pilot(pilot) - except Exception: - app.exit() - raise + with self._context(): + try: + await auto_pilot(pilot) + except Exception: + app.exit() + raise pilot = Pilot(app) - active_message_pump.set(self) auto_pilot_task = create_task( run_auto_pilot(auto_pilot, pilot), name=repr(pilot) ) @@ -1816,18 +1821,19 @@ async def run_app() -> None: """Run the app.""" self._loop = asyncio.get_running_loop() self._thread_id = threading.get_ident() - try: - await self.run_async( - headless=headless, - inline=inline, - inline_no_clear=inline_no_clear, - mouse=mouse, - size=size, - auto_pilot=auto_pilot, - ) - finally: - self._loop = None - self._thread_id = 0 + with self._context(): + try: + await self.run_async( + headless=headless, + inline=inline, + inline_no_clear=inline_no_clear, + mouse=mouse, + size=size, + auto_pilot=auto_pilot, + ) + finally: + self._loop = None + self._thread_id = 0 if _ASYNCIO_GET_EVENT_LOOP_IS_DEPRECATED: # N.B. This doesn't work with Python<3.10, as we end up with 2 event loops: @@ -2680,6 +2686,16 @@ def _build_driver( ) return driver + async def _init_devtools(self): + if self.devtools is not None: + from textual_dev.client import DevtoolsConnectionError + + try: + await self.devtools.connect() + self.log.system(f"Connected to devtools ( {self.devtools.url} )") + except DevtoolsConnectionError: + self.log.system(f"Couldn't connect to devtools ( {self.devtools.url} )") + async def _process_messages( self, ready_callback: CallbackType | None = None, @@ -2690,54 +2706,44 @@ async def _process_messages( terminal_size: tuple[int, int] | None = None, message_hook: Callable[[Message], None] | None = None, ) -> None: - self._set_active() - active_message_pump.set(self) - - if self.devtools is not None: - from textual_dev.client import DevtoolsConnectionError + async def app_prelude() -> bool: + await self._init_devtools() + self.log.system("---") + self.log.system(loop=asyncio.get_running_loop()) + self.log.system(features=self.features) + if constants.LOG_FILE is not None: + _log_path = os.path.abspath(constants.LOG_FILE) + self.log.system(f"Writing logs to {_log_path!r}") try: - await self.devtools.connect() - self.log.system(f"Connected to devtools ( {self.devtools.url} )") - except DevtoolsConnectionError: - self.log.system(f"Couldn't connect to devtools ( {self.devtools.url} )") - - self.log.system("---") - - self.log.system(loop=asyncio.get_running_loop()) - self.log.system(features=self.features) - if constants.LOG_FILE is not None: - _log_path = os.path.abspath(constants.LOG_FILE) - self.log.system(f"Writing logs to {_log_path!r}") - - try: - if self.css_path: - self.stylesheet.read_all(self.css_path) - for read_from, css, tie_breaker, scope in self._get_default_css(): - self.stylesheet.add_source( - css, - read_from=read_from, - is_default_css=True, - tie_breaker=tie_breaker, - scope=scope, - ) - if self.CSS: - try: - app_path = inspect.getfile(self.__class__) - except (TypeError, OSError): - app_path = "" - read_from = (app_path, f"{self.__class__.__name__}.CSS") - self.stylesheet.add_source( - self.CSS, read_from=read_from, is_default_css=False - ) - except Exception as error: - self._handle_exception(error) - self._print_error_renderables() - return + if self.css_path: + self.stylesheet.read_all(self.css_path) + for read_from, css, tie_breaker, scope in self._get_default_css(): + self.stylesheet.add_source( + css, + read_from=read_from, + is_default_css=True, + tie_breaker=tie_breaker, + scope=scope, + ) + if self.CSS: + try: + app_path = inspect.getfile(self.__class__) + except (TypeError, OSError): + app_path = "" + read_from = (app_path, f"{self.__class__.__name__}.CSS") + self.stylesheet.add_source( + self.CSS, read_from=read_from, is_default_css=False + ) + except Exception as error: + self._handle_exception(error) + self._print_error_renderables() + return False - if self.css_monitor: - self.set_interval(0.25, self.css_monitor, name="css monitor") - self.log.system("STARTED", self.css_monitor) + if self.css_monitor: + self.set_interval(0.25, self.css_monitor, name="css monitor") + self.log.system("STARTED", self.css_monitor) + return True async def run_process_messages(): """The main message loop, invoke below.""" @@ -2788,42 +2794,45 @@ async def invoke_ready_callback() -> None: finally: await Timer._stop_all(self._timers) - self._running = True - try: - load_event = events.Load() - await self._dispatch_message(load_event) + with self._context(): + if not await app_prelude(): + return + self._running = True + try: + load_event = events.Load() + await self._dispatch_message(load_event) - driver = self._driver = self._build_driver( - headless=headless, - inline=inline, - mouse=mouse, - size=terminal_size, - ) - self.log(driver=driver) + driver = self._driver = self._build_driver( + headless=headless, + inline=inline, + mouse=mouse, + size=terminal_size, + ) + self.log(driver=driver) - if not self._exit: - driver.start_application_mode() - try: - with redirect_stdout(self._capture_stdout): - with redirect_stderr(self._capture_stderr): - await run_process_messages() + if not self._exit: + driver.start_application_mode() + try: + with redirect_stdout(self._capture_stdout): + with redirect_stderr(self._capture_stderr): + await run_process_messages() - finally: - if hasattr(self, "_watchers"): - self._watchers.clear() - if self._driver.is_inline: - cursor_x, cursor_y = self._previous_cursor_position - self._driver.write( - Control.move(-cursor_x, -cursor_y + 1).segment.text - ) - if inline_no_clear and not not self.app._exit_renderables: - console = Console() - console.print(self.screen._compositor) - console.print() + finally: + if hasattr(self, "_watchers"): + self._watchers.clear() + if self._driver.is_inline: + cursor_x, cursor_y = self._previous_cursor_position + self._driver.write( + Control.move(-cursor_x, -cursor_y + 1).segment.text + ) + if inline_no_clear and not not self.app._exit_renderables: + console = Console() + console.print(self.screen._compositor) + console.print() - driver.stop_application_mode() - except Exception as error: - self._handle_exception(error) + driver.stop_application_mode() + except Exception as error: + self._handle_exception(error) async def _pre_process(self) -> bool: """Special case for the app, which doesn't need the functionality in MessagePump.""" @@ -3049,7 +3058,6 @@ async def _shutdown(self) -> None: await self._close_all() await self._close_messages() - await self._dispatch_message(events.Unmount()) if self._driver is not None: diff --git a/src/textual/command.py b/src/textual/command.py index 54bc5a4c73..55789e0bb5 100644 --- a/src/textual/command.py +++ b/src/textual/command.py @@ -241,6 +241,7 @@ async def _wait_init(self) -> None: """Wait for initialization.""" if self._init_task is not None: await self._init_task + self._init_task = None async def startup(self) -> None: """Called after the Provider is initialized, but before any calls to `search`.""" diff --git a/src/textual/message_pump.py b/src/textual/message_pump.py index 8a58d6c752..2e0316bb3b 100644 --- a/src/textual/message_pump.py +++ b/src/textual/message_pump.py @@ -229,7 +229,7 @@ def app(self) -> "App[object]": if node is None: raise NoActiveAppError() node = node._parent - active_app.set(node) + return node @property @@ -501,26 +501,27 @@ def _start_messages(self) -> None: async def _process_messages(self) -> None: self._running = True - active_message_pump.set(self) - if not await self._pre_process(): - self._running = False - return + with self._context(): + if not await self._pre_process(): + self._running = False + return - try: - await self._process_messages_loop() - except CancelledError: - pass - finally: - self._running = False try: - if self._timers: - await Timer._stop_all(self._timers) - self._timers.clear() + await self._process_messages_loop() + except CancelledError: + pass finally: - if hasattr(self, "_watchers"): - self._watchers.clear() - await self._message_loop_exit() + self._running = False + try: + if self._timers: + await Timer._stop_all(self._timers) + self._timers.clear() + if hasattr(self, "_watchers"): + self._watchers.clear() + finally: + await self._message_loop_exit() + self._task = None async def _message_loop_exit(self) -> None: """Called when the message loop has completed.""" @@ -560,6 +561,15 @@ def _close_messages_no_wait(self) -> None: """Request the message queue to immediately exit.""" self._message_queue.put_nowait(messages.CloseMessages()) + @contextmanager + def _context(self) -> Generator[None, None, None]: + """Context manager to set ContextVars.""" + reset_token = active_message_pump.set(self) + try: + yield + finally: + active_message_pump.reset(reset_token) + async def _on_close_messages(self, message: messages.CloseMessages) -> None: await self._close_messages() diff --git a/src/textual/reactive.py b/src/textual/reactive.py index 162b4bd3f5..4fa8deb59d 100644 --- a/src/textual/reactive.py +++ b/src/textual/reactive.py @@ -24,7 +24,6 @@ from . import events from ._callback import count_parameters -from ._context import active_message_pump from ._types import ( MessageTarget, WatchCallbackBothValuesType, @@ -82,8 +81,8 @@ def invoke_watcher( _rich_traceback_omit = True param_count = count_parameters(watch_function) - reset_token = active_message_pump.set(watcher_object) - try: + + with watcher_object._context(): if param_count == 2: watch_result = cast(WatchCallbackBothValuesType, watch_function)( old_value, value @@ -97,8 +96,6 @@ def invoke_watcher( watcher_object.call_next( partial(await_watcher, watcher_object, watch_result) ) - finally: - active_message_pump.reset(reset_token) @rich.repr.auto @@ -203,7 +200,7 @@ def _reset_object(cls, obj: object) -> None: Args: obj: A reactive object. """ - getattr(obj, "__watchers", {}).clear() + getattr(obj, "_watchers", {}).clear() getattr(obj, "__computes", []).clear() def __set_name__(self, owner: Type[MessageTarget], name: str) -> None: @@ -351,7 +348,7 @@ def _check_watchers(cls, obj: Reactable, name: str, old_value: Any) -> None: # Process "global" watchers watchers: list[tuple[Reactable, WatchCallbackType]] - watchers = getattr(obj, "__watchers", {}).get(name, []) + watchers = getattr(obj, "_watchers", {}).get(name, []) # Remove any watchers for reactables that have since closed if watchers: watchers[:] = [ diff --git a/src/textual/screen.py b/src/textual/screen.py index 5affe749f6..4ea8eaba26 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -979,11 +979,8 @@ async def _invoke_and_clear_callbacks(self) -> None: callbacks = self._callbacks[:] self._callbacks.clear() for callback, message_pump in callbacks: - reset_token = active_message_pump.set(message_pump) - try: + with message_pump._context(): await invoke(callback) - finally: - active_message_pump.reset(reset_token) def _invoke_later(self, callback: CallbackType, sender: MessagePump) -> None: """Enqueue a callback to be invoked after the screen is repainted. @@ -1014,12 +1011,14 @@ def _push_result_callback( ) async def _message_loop_exit(self) -> None: + await super()._message_loop_exit() self._compositor.clear() self._dirty_widgets.clear() self._dirty_regions.clear() self._arrangement_cache.clear() self.screen_layout_refresh_signal.unsubscribe(self) self._nodes._clear() + self._task = None def _pop_result_callback(self) -> None: """Remove the latest result callback from the stack.""" diff --git a/src/textual/timer.py b/src/textual/timer.py index 710b684b6b..e593203ee3 100644 --- a/src/textual/timer.py +++ b/src/textual/timer.py @@ -97,7 +97,6 @@ def stop(self) -> None: self._active.set() self._task.cancel() self._task = None - return self._task @classmethod async def _stop_all(cls, timers: Iterable[Timer]) -> None: diff --git a/src/textual/widget.py b/src/textual/widget.py index e0765560d0..58305d0497 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -49,7 +49,7 @@ from ._animator import DEFAULT_EASING, Animatable, BoundAnimator, EasingFunction from ._arrange import DockArrangeResult, arrange from ._compose import compose -from ._context import NoActiveAppError, active_app +from ._context import NoActiveAppError from ._debug import get_caller_file_and_line from ._dispatch_key import dispatch_key from ._easing import DEFAULT_SCROLL_EASING @@ -1197,7 +1197,6 @@ async def recompose(self) -> None: await self.query_children("*").exclude(".-textual-system").remove() if self.is_attached: compose_nodes = compose(self) - print("COMPOSE", compose_nodes) await self.mount_all(compose_nodes) def _post_register(self, app: App) -> None: @@ -1885,7 +1884,7 @@ def _console(self) -> Console: Returns: A Rich console object. """ - return active_app.get().console + return self.app.console @property def _has_relative_children_width(self) -> bool: diff --git a/src/textual/widgets/_footer.py b/src/textual/widgets/_footer.py index df6b0e0176..98214a56f7 100644 --- a/src/textual/widgets/_footer.py +++ b/src/textual/widgets/_footer.py @@ -96,9 +96,6 @@ def __init__( if tooltip: self.tooltip = tooltip - def __repr__(self) -> str: - return f"FooterKey({self._parent!r})" - def render(self) -> Text: key_style = self.get_component_rich_style("footer-key--key") description_style = self.get_component_rich_style("footer-key--description") diff --git a/src/textual/worker.py b/src/textual/worker.py index 9ad60a64ac..c0b1cbbffd 100644 --- a/src/textual/worker.py +++ b/src/textual/worker.py @@ -359,30 +359,30 @@ async def _run(self, app: App) -> None: Args: app: App instance. """ - app._set_active() - active_worker.set(self) + with app._context(): + active_worker.set(self) - self.state = WorkerState.RUNNING - app.log.worker(self) - try: - self._result = await self.run() - except asyncio.CancelledError as error: - self.state = WorkerState.CANCELLED - self._error = error - app.log.worker(self) - except Exception as error: - self.state = WorkerState.ERROR - self._error = error - app.log.worker(self, "failed", repr(error)) - from rich.traceback import Traceback - - app.log.worker(Traceback()) - if self.exit_on_error: - worker_failed = WorkerFailed(self._error) - app._handle_exception(worker_failed) - else: - self.state = WorkerState.SUCCESS + self.state = WorkerState.RUNNING app.log.worker(self) + try: + self._result = await self.run() + except asyncio.CancelledError as error: + self.state = WorkerState.CANCELLED + self._error = error + app.log.worker(self) + except Exception as error: + self.state = WorkerState.ERROR + self._error = error + app.log.worker(self, "failed", repr(error)) + from rich.traceback import Traceback + + app.log.worker(Traceback()) + if self.exit_on_error: + worker_failed = WorkerFailed(self._error) + app._handle_exception(worker_failed) + else: + self.state = WorkerState.SUCCESS + app.log.worker(self) def _start( self, app: App, done_callback: Callable[[Worker], None] | None = None diff --git a/tests/layouts/test_horizontal.py b/tests/layouts/test_horizontal.py index 05ce4c8905..62d766aecb 100644 --- a/tests/layouts/test_horizontal.py +++ b/tests/layouts/test_horizontal.py @@ -19,11 +19,11 @@ def compose(self) -> ComposeResult: yield self.horizontal app = HorizontalAutoWidth() - async with app.run_test(): - yield app + yield app async def test_horizontal_get_content_width(app): - size = app.screen.size - width = app.horizontal.get_content_width(size, size) - assert width == 15 + async with app.run_test(): + size = app.screen.size + width = app.horizontal.get_content_width(size, size) + assert width == 15 diff --git a/tests/test_focus.py b/tests/test_focus.py index ee955e23e5..a23046bc94 100644 --- a/tests/test_focus.py +++ b/tests/test_focus.py @@ -22,46 +22,47 @@ class ChildrenFocusableOnly(Widget, can_focus=False, can_focus_children=True): @pytest.fixture def screen() -> Screen: app = App() - app._set_active() - app.push_screen(Screen()) - screen = app.screen + with app._context(): + app.push_screen(Screen()) - # The classes even/odd alternate along the focus chain. - # The classes in/out identify nested widgets. - screen._add_children( - Focusable(id="foo", classes="a"), - NonFocusable(id="bar"), - Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b"), - NonFocusable(Focusable(id="Jessica", classes="a"), id="container2"), - Focusable(id="baz", classes="b"), - ChildrenFocusableOnly(Focusable(id="child", classes="c")), - ) + screen = app.screen - return screen + # The classes even/odd alternate along the focus chain. + # The classes in/out identify nested widgets. + screen._add_children( + Focusable(id="foo", classes="a"), + NonFocusable(id="bar"), + Focusable(Focusable(id="Paul", classes="c"), id="container1", classes="b"), + NonFocusable(Focusable(id="Jessica", classes="a"), id="container2"), + Focusable(id="baz", classes="b"), + ChildrenFocusableOnly(Focusable(id="child", classes="c")), + ) + + return screen def test_focus_chain(): app = App() - app._set_active() - app.push_screen(Screen()) + with app._context(): + app.push_screen(Screen()) - screen = app.screen + screen = app.screen - # Check empty focus chain - assert not screen.focus_chain + # Check empty focus chain + assert not screen.focus_chain - app.screen._add_children( - Focusable(id="foo"), - NonFocusable(id="bar"), - Focusable(Focusable(id="Paul"), id="container1"), - NonFocusable(Focusable(id="Jessica"), id="container2"), - Focusable(id="baz"), - ChildrenFocusableOnly(Focusable(id="child")), - ) + app.screen._add_children( + Focusable(id="foo"), + NonFocusable(id="bar"), + Focusable(Focusable(id="Paul"), id="container1"), + NonFocusable(Focusable(id="Jessica"), id="container2"), + Focusable(id="baz"), + ChildrenFocusableOnly(Focusable(id="child")), + ) - focus_chain = [widget.id for widget in screen.focus_chain] - assert focus_chain == ["foo", "container1", "Paul", "baz", "child"] + focus_chain = [widget.id for widget in screen.focus_chain] + assert focus_chain == ["foo", "container1", "Paul", "baz", "child"] def test_allow_focus(): @@ -90,18 +91,19 @@ def allow_focus_children(self) -> bool: return False app = App() - app._set_active() - app.push_screen(Screen()) - app.screen._add_children( - Focusable(id="foo"), - NonFocusable(id="bar"), - FocusableContainer(Button("egg", id="egg")), - NonFocusableContainer(Button("EGG", id="qux")), - ) - assert [widget.id for widget in app.screen.focus_chain] == ["foo", "egg"] - assert focusable_allow_focus_called - assert non_focusable_allow_focus_called + with app._context(): + app.push_screen(Screen()) + + app.screen._add_children( + Focusable(id="foo"), + NonFocusable(id="bar"), + FocusableContainer(Button("egg", id="egg")), + NonFocusableContainer(Button("EGG", id="qux")), + ) + assert [widget.id for widget in app.screen.focus_chain] == ["foo", "egg"] + assert focusable_allow_focus_called + assert non_focusable_allow_focus_called def test_focus_next_and_previous(screen: Screen): @@ -188,47 +190,47 @@ def test_focus_next_and_previous_with_str_selector(screen: Screen): def test_focus_next_and_previous_with_type_selector_without_self(): """Test moving the focus with a selector that does not match the currently focused node.""" app = App() - app._set_active() - app.push_screen(Screen()) - - screen = app.screen - - from textual.containers import Horizontal, VerticalScroll - from textual.widgets import Button, Input, Switch - - screen._add_children( - VerticalScroll( - Horizontal( - Input(id="w3"), - Switch(id="w4"), - Input(id="w5"), - Button(id="w6"), - Switch(id="w7"), - id="w2", - ), - Horizontal( - Button(id="w9"), - Switch(id="w10"), - Button(id="w11"), - Input(id="w12"), - Input(id="w13"), - id="w8", - ), - id="w1", + with app._context(): + app.push_screen(Screen()) + + screen = app.screen + + from textual.containers import Horizontal, VerticalScroll + from textual.widgets import Button, Input, Switch + + screen._add_children( + VerticalScroll( + Horizontal( + Input(id="w3"), + Switch(id="w4"), + Input(id="w5"), + Button(id="w6"), + Switch(id="w7"), + id="w2", + ), + Horizontal( + Button(id="w9"), + Switch(id="w10"), + Button(id="w11"), + Input(id="w12"), + Input(id="w13"), + id="w8", + ), + id="w1", + ) ) - ) - screen.set_focus(screen.query_one("#w3")) - assert screen.focused.id == "w3" + screen.set_focus(screen.query_one("#w3")) + assert screen.focused.id == "w3" - assert screen.focus_next(Button).id == "w6" - assert screen.focus_next(Switch).id == "w7" - assert screen.focus_next(Input).id == "w12" + assert screen.focus_next(Button).id == "w6" + assert screen.focus_next(Switch).id == "w7" + assert screen.focus_next(Input).id == "w12" - assert screen.focus_previous(Button).id == "w11" - assert screen.focus_previous(Switch).id == "w10" - assert screen.focus_previous(Button).id == "w9" - assert screen.focus_previous(Input).id == "w5" + assert screen.focus_previous(Button).id == "w11" + assert screen.focus_previous(Switch).id == "w10" + assert screen.focus_previous(Button).id == "w9" + assert screen.focus_previous(Input).id == "w5" def test_focus_next_and_previous_with_str_selector_without_self(screen: Screen): diff --git a/tests/test_path.py b/tests/test_path.py index d7088f8be7..3d5203a6e8 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -30,14 +30,15 @@ class ListPathApp(App[None]): @pytest.mark.parametrize( - "app,expected_css_path_attribute", + "app_class,expected_css_path_attribute", [ - (RelativePathObjectApp(), [APP_DIR / "test.tcss"]), - (RelativePathStrApp(), [APP_DIR / "test.tcss"]), - (AbsolutePathObjectApp(), [Path("/tmp/test.tcss")]), - (AbsolutePathStrApp(), [Path("/tmp/test.tcss")]), - (ListPathApp(), [APP_DIR / "test.tcss", Path("/another/path.tcss")]), + (RelativePathObjectApp, [APP_DIR / "test.tcss"]), + (RelativePathStrApp, [APP_DIR / "test.tcss"]), + (AbsolutePathObjectApp, [Path("/tmp/test.tcss")]), + (AbsolutePathStrApp, [Path("/tmp/test.tcss")]), + (ListPathApp, [APP_DIR / "test.tcss", Path("/another/path.tcss")]), ], ) -def test_css_paths_of_various_types(app, expected_css_path_attribute): +def test_css_paths_of_various_types(app_class, expected_css_path_attribute): + app = app_class() assert app.css_path == [path.absolute() for path in expected_css_path_attribute] diff --git a/tests/test_screens.py b/tests/test_screens.py index 00f3504361..a53abb7049 100644 --- a/tests/test_screens.py +++ b/tests/test_screens.py @@ -74,89 +74,88 @@ async def test_screens(): # There should be nothing in the children since the app hasn't run yet assert not app._nodes assert not app.children - app._set_active() - - with pytest.raises(ScreenStackError): - app.screen - - assert not app._installed_screens - - screen1 = Screen(name="screen1") - screen2 = Screen(name="screen2") - screen3 = Screen(name="screen3") - - # installs screens - app.install_screen(screen1, "screen1") - app.install_screen(screen2, "screen2") - - # Installing a screen does not add it to the DOM - assert not app._nodes - assert not app.children + with app._context(): + with pytest.raises(ScreenStackError): + app.screen + + assert not app._installed_screens + + screen1 = Screen(name="screen1") + screen2 = Screen(name="screen2") + screen3 = Screen(name="screen3") + + # installs screens + app.install_screen(screen1, "screen1") + app.install_screen(screen2, "screen2") + + # Installing a screen does not add it to the DOM + assert not app._nodes + assert not app.children + + # Check they are installed + assert app.is_screen_installed("screen1") + assert app.is_screen_installed("screen2") + + assert app.get_screen("screen1") is screen1 + with pytest.raises(KeyError): + app.get_screen("foo") + + # Check screen3 is not installed + assert not app.is_screen_installed("screen3") + + # Installs screen3 + app.install_screen(screen3, "screen3") + # Confirm installed + assert app.is_screen_installed("screen3") + + # Check screen stack is empty + assert app.screen_stack == [] + # Push a screen + await app.push_screen("screen1") + # Check it is on the stack + assert app.screen_stack == [screen1] + # Check it is current + assert app.screen is screen1 + # There should be one item in the children view + assert app.children == (screen1,) + + # Switch to another screen + await app.switch_screen("screen2") + # Check it has changed the stack and that it is current + assert app.screen_stack == [screen2] + assert app.screen is screen2 + assert app.children == (screen2,) + + # Push another screen + await app.push_screen("screen3") + assert app.screen_stack == [screen2, screen3] + assert app.screen is screen3 + # Only the current screen is in children + assert app.children == (screen3,) + + # Pop a screen + await app.pop_screen() + assert app.screen is screen2 + assert app.screen_stack == [screen2] + + # Uninstall screens + app.uninstall_screen(screen1) + assert not app.is_screen_installed(screen1) + app.uninstall_screen("screen3") + assert not app.is_screen_installed(screen1) + + # Check we can't uninstall a screen on the stack + with pytest.raises(ScreenStackError): + app.uninstall_screen(screen2) - # Check they are installed - assert app.is_screen_installed("screen1") - assert app.is_screen_installed("screen2") - - assert app.get_screen("screen1") is screen1 - with pytest.raises(KeyError): - app.get_screen("foo") - - # Check screen3 is not installed - assert not app.is_screen_installed("screen3") - - # Installs screen3 - app.install_screen(screen3, "screen3") - # Confirm installed - assert app.is_screen_installed("screen3") - - # Check screen stack is empty - assert app.screen_stack == [] - # Push a screen - await app.push_screen("screen1") - # Check it is on the stack - assert app.screen_stack == [screen1] - # Check it is current - assert app.screen is screen1 - # There should be one item in the children view - assert app.children == (screen1,) - - # Switch to another screen - await app.switch_screen("screen2") - # Check it has changed the stack and that it is current - assert app.screen_stack == [screen2] - assert app.screen is screen2 - assert app.children == (screen2,) - - # Push another screen - await app.push_screen("screen3") - assert app.screen_stack == [screen2, screen3] - assert app.screen is screen3 - # Only the current screen is in children - assert app.children == (screen3,) - - # Pop a screen - await app.pop_screen() - assert app.screen is screen2 - assert app.screen_stack == [screen2] - - # Uninstall screens - app.uninstall_screen(screen1) - assert not app.is_screen_installed(screen1) - app.uninstall_screen("screen3") - assert not app.is_screen_installed(screen1) - - # Check we can't uninstall a screen on the stack - with pytest.raises(ScreenStackError): - app.uninstall_screen(screen2) - - # Check we can't pop last screen - with pytest.raises(ScreenStackError): - app.pop_screen() + # Check we can't pop last screen + with pytest.raises(ScreenStackError): + app.pop_screen() - screen1.remove() - screen2.remove() - screen3.remove() - await app._shutdown() + screen1.remove() + screen2.remove() + screen3.remove() + await app._shutdown() async def test_auto_focus_on_screen_if_app_auto_focus_is_none(): diff --git a/tests/test_unmount.py b/tests/test_unmount.py index 324a3812a3..36c0acbeb1 100644 --- a/tests/test_unmount.py +++ b/tests/test_unmount.py @@ -50,4 +50,7 @@ async def on_mount(self) -> None: "MyScreen#main", ] + print(unmount_ids) + print(expected) + assert unmount_ids == expected diff --git a/tests/test_widget.py b/tests/test_widget.py index 2652aec54d..6643916bbb 100644 --- a/tests/test_widget.py +++ b/tests/test_widget.py @@ -57,22 +57,21 @@ def render(self) -> str: widget3 = TextWidget("foo\nbar\nbaz", id="widget3") app = App() - app._set_active() + with app._context(): + width = widget1.get_content_width(Size(20, 20), Size(80, 24)) + height = widget1.get_content_height(Size(20, 20), Size(80, 24), width) + assert width == 3 + assert height == 1 - width = widget1.get_content_width(Size(20, 20), Size(80, 24)) - height = widget1.get_content_height(Size(20, 20), Size(80, 24), width) - assert width == 3 - assert height == 1 + width = widget2.get_content_width(Size(20, 20), Size(80, 24)) + height = widget2.get_content_height(Size(20, 20), Size(80, 24), width) + assert width == 3 + assert height == 2 - width = widget2.get_content_width(Size(20, 20), Size(80, 24)) - height = widget2.get_content_height(Size(20, 20), Size(80, 24), width) - assert width == 3 - assert height == 2 - - width = widget3.get_content_width(Size(20, 20), Size(80, 24)) - height = widget3.get_content_height(Size(20, 20), Size(80, 24), width) - assert width == 3 - assert height == 3 + width = widget3.get_content_width(Size(20, 20), Size(80, 24)) + height = widget3.get_content_height(Size(20, 20), Size(80, 24), width) + assert width == 3 + assert height == 3 class GetByIdApp(App): @@ -87,34 +86,38 @@ def compose(self) -> ComposeResult: id="parent", ) + @property + def parent(self) -> Widget: + return self.query_one("#parent") + @pytest.fixture async def hierarchy_app(): app = GetByIdApp() - async with app.run_test(): - yield app - - -@pytest.fixture -async def parent(hierarchy_app): - yield hierarchy_app.get_widget_by_id("parent") + yield app -def test_get_child_by_id_gets_first_child(parent): - child = parent.get_child_by_id(id="child1") - assert child.id == "child1" - assert child.get_child_by_id(id="grandchild1").id == "grandchild1" - assert parent.get_child_by_id(id="child2").id == "child2" +async def test_get_child_by_id_gets_first_child(hierarchy_app): + async with hierarchy_app.run_test(): + parent = hierarchy_app.parent + child = parent.get_child_by_id(id="child1") + assert child.id == "child1" + assert child.get_child_by_id(id="grandchild1").id == "grandchild1" + assert parent.get_child_by_id(id="child2").id == "child2" -def test_get_child_by_id_no_matching_child(parent): - with pytest.raises(NoMatches): - parent.get_child_by_id(id="doesnt-exist") +async def test_get_child_by_id_no_matching_child(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + with pytest.raises(NoMatches): + parent.get_child_by_id(id="doesnt-exist") -def test_get_child_by_id_only_immediate_descendents(parent): - with pytest.raises(NoMatches): - parent.get_child_by_id(id="grandchild1") +async def test_get_child_by_id_only_immediate_descendents(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + with pytest.raises(NoMatches): + parent.get_child_by_id(id="grandchild1") async def test_get_child_by_type(): @@ -135,51 +138,65 @@ def compose(self) -> ComposeResult: app.get_child_by_type(Label) -def test_get_widget_by_id_no_matching_child(parent): - with pytest.raises(NoMatches): - parent.get_widget_by_id(id="i-dont-exist") +async def test_get_widget_by_id_no_matching_child(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + with pytest.raises(NoMatches): + parent.get_widget_by_id(id="i-dont-exist") -def test_get_widget_by_id_non_immediate_descendants(parent): - result = parent.get_widget_by_id("grandchild1") - assert result.id == "grandchild1" +async def test_get_widget_by_id_non_immediate_descendants(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + result = parent.get_widget_by_id("grandchild1") + assert result.id == "grandchild1" -def test_get_widget_by_id_immediate_descendants(parent): - result = parent.get_widget_by_id("child1") - assert result.id == "child1" +async def test_get_widget_by_id_immediate_descendants(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + result = parent.get_widget_by_id("child1") + assert result.id == "child1" -def test_get_widget_by_id_doesnt_return_self(parent): - with pytest.raises(NoMatches): - parent.get_widget_by_id("parent") +async def test_get_widget_by_id_doesnt_return_self(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + with pytest.raises(NoMatches): + parent.get_widget_by_id("parent") -def test_get_widgets_app_delegated(hierarchy_app, parent): +async def test_get_widgets_app_delegated(hierarchy_app): # Check that get_child_by_id finds the parent, which is a child of the default Screen - queried_parent = hierarchy_app.get_child_by_id("parent") - assert queried_parent is parent + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + queried_parent = hierarchy_app.get_child_by_id("parent") + assert queried_parent is parent - # Check that the grandchild (descendant of the default screen) is found - grandchild = hierarchy_app.get_widget_by_id("grandchild1") - assert grandchild.id == "grandchild1" + # Check that the grandchild (descendant of the default screen) is found + grandchild = hierarchy_app.get_widget_by_id("grandchild1") + assert grandchild.id == "grandchild1" -def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(parent): - widget1 = Widget(id="hello") - widget2 = Widget(id="hello") +async def test_widget_mount_ids_must_be_unique_mounting_all_in_one_go(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + widget1 = Widget(id="hello") + widget2 = Widget(id="hello") - with pytest.raises(MountError): - parent.mount(widget1, widget2) + with pytest.raises(MountError): + parent.mount(widget1, widget2) -def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(parent): - widget1 = Widget(id="hello") - widget2 = Widget(id="hello") +async def test_widget_mount_ids_must_be_unique_mounting_multiple_calls(hierarchy_app): + async with hierarchy_app.run_test() as pilot: + parent = pilot.app.parent + widget1 = Widget(id="hello") + widget2 = Widget(id="hello") - parent.mount(widget1) - with pytest.raises(DuplicateIds): - parent.mount(widget2) + parent.mount(widget1) + with pytest.raises(DuplicateIds): + parent.mount(widget2) def test_get_pseudo_class_state(): diff --git a/tests/test_widget_mounting.py b/tests/test_widget_mounting.py index 7babed4781..79f7364ad6 100644 --- a/tests/test_widget_mounting.py +++ b/tests/test_widget_mounting.py @@ -116,8 +116,10 @@ async def test_mount_via_app() -> None: await pilot.app.mount(Static(), before="Static") -def test_mount_error() -> None: +async def test_mount_error() -> None: """Mounting a widget on an un-mounted widget should raise an error.""" - with pytest.raises(MountError): - widget = Widget() - widget.mount(Static()) + app = App() + async with app.run_test(): + with pytest.raises(MountError): + widget = Widget() + widget.mount(Static()) diff --git a/tests/text_area/test_history.py b/tests/text_area/test_history.py index 8d50a63f83..1d0c7a0b0b 100644 --- a/tests/text_area/test_history.py +++ b/tests/text_area/test_history.py @@ -6,7 +6,6 @@ from textual.app import App, ComposeResult from textual.events import Paste -from textual.pilot import Pilot from textual.widgets import TextArea from textual.widgets.text_area import EditHistory, Selection @@ -57,300 +56,346 @@ async def text_area(pilot): return pilot.app.text_area -async def test_simple_undo_redo(pilot, text_area: TextArea): - text_area.insert("123", (0, 0)) +async def test_simple_undo_redo(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.insert("123", (0, 0)) - assert text_area.text == "123" - text_area.undo() - assert text_area.text == "" - text_area.redo() - assert text_area.text == "123" + assert text_area.text == "123" + text_area.undo() + assert text_area.text == "" + text_area.redo() + assert text_area.text == "123" -async def test_undo_selection_retained(pilot: Pilot, text_area: TextArea): +async def test_undo_selection_retained(): # Select a range of text and press backspace. - text_area.text = SIMPLE_TEXT - text_area.selection = Selection((0, 0), (2, 3)) - await pilot.press("backspace") - assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n" - assert text_area.selection == Selection.cursor((0, 0)) - - # Undo the deletion - the text comes back, and the selection is restored. - text_area.undo() - assert text_area.selection == Selection((0, 0), (2, 3)) - assert text_area.text == SIMPLE_TEXT - - # Redo the deletion - the text is gone again. The selection goes to the post-delete location. - text_area.redo() - assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n" - assert text_area.selection == Selection.cursor((0, 0)) - - -async def test_undo_checkpoint_created_on_cursor_move( - pilot: Pilot, text_area: TextArea -): - text_area.text = SIMPLE_TEXT - # Characters are inserted on line 0 and 1. - checkpoint_one = text_area.text - checkpoint_one_selection = text_area.selection - await pilot.press("1") # Added to initial batch. - - # This cursor movement ensures a new checkpoint is created. - post_insert_one_location = text_area.selection - await pilot.press("down") - - checkpoint_two = text_area.text - checkpoint_two_selection = text_area.selection - await pilot.press("2") # Added to new batch. - - checkpoint_three = text_area.text - checkpoint_three_selection = text_area.selection - - # Going back to checkpoint two - text_area.undo() - assert text_area.text == checkpoint_two - assert text_area.selection == checkpoint_two_selection - - # Back again to checkpoint one (initial state) - text_area.undo() - assert text_area.text == checkpoint_one - assert text_area.selection == checkpoint_one_selection - - # Redo to move forward to checkpoint two. - text_area.redo() - assert text_area.text == checkpoint_two - assert text_area.selection == post_insert_one_location - - # Redo to move forward to checkpoint three. - text_area.redo() - assert text_area.text == checkpoint_three - assert text_area.selection == checkpoint_three_selection - - -async def test_setting_text_property_resets_history(pilot: Pilot, text_area: TextArea): - await pilot.press("1") - - # Programmatically setting text, which should invalidate the history - text = "Hello, world!" - text_area.text = text + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = SIMPLE_TEXT + text_area.selection = Selection((0, 0), (2, 3)) + await pilot.press("backspace") + assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n" + assert text_area.selection == Selection.cursor((0, 0)) - # The undo doesn't do anything, since we set the `text` property. - text_area.undo() - assert text_area.text == text + # Undo the deletion - the text comes back, and the selection is restored. + text_area.undo() + assert text_area.selection == Selection((0, 0), (2, 3)) + assert text_area.text == SIMPLE_TEXT + # Redo the deletion - the text is gone again. The selection goes to the post-delete location. + text_area.redo() + assert text_area.text == "NO\nPQRST\nUVWXY\nZ\n" + assert text_area.selection == Selection.cursor((0, 0)) -async def test_edits_batched_by_time(pilot: Pilot, text_area: TextArea): - # The first "12" is batched since they happen within 2 seconds. - text_area.history.mock_time = 0 - await pilot.press("1") - text_area.history.mock_time = 1.0 - await pilot.press("2") +async def test_undo_checkpoint_created_on_cursor_move(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = SIMPLE_TEXT + # Characters are inserted on line 0 and 1. + checkpoint_one = text_area.text + checkpoint_one_selection = text_area.selection + await pilot.press("1") # Added to initial batch. + + # This cursor movement ensures a new checkpoint is created. + post_insert_one_location = text_area.selection + await pilot.press("down") + + checkpoint_two = text_area.text + checkpoint_two_selection = text_area.selection + await pilot.press("2") # Added to new batch. + + checkpoint_three = text_area.text + checkpoint_three_selection = text_area.selection + + # Going back to checkpoint two + text_area.undo() + assert text_area.text == checkpoint_two + assert text_area.selection == checkpoint_two_selection + + # Back again to checkpoint one (initial state) + text_area.undo() + assert text_area.text == checkpoint_one + assert text_area.selection == checkpoint_one_selection + + # Redo to move forward to checkpoint two. + text_area.redo() + assert text_area.text == checkpoint_two + assert text_area.selection == post_insert_one_location + + # Redo to move forward to checkpoint three. + text_area.redo() + assert text_area.text == checkpoint_three + assert text_area.selection == checkpoint_three_selection + + +async def test_setting_text_property_resets_history(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + await pilot.press("1") - # Since "3" appears 10 seconds later, it's in a separate batch. - text_area.history.mock_time += 10.0 - await pilot.press("3") + # Programmatically setting text, which should invalidate the history + text = "Hello, world!" + text_area.text = text - assert text_area.text == "123" + # The undo doesn't do anything, since we set the `text` property. + text_area.undo() + assert text_area.text == text - text_area.undo() - assert text_area.text == "12" - text_area.undo() - assert text_area.text == "" +async def test_edits_batched_by_time(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + # The first "12" is batched since they happen within 2 seconds. + text_area.history.mock_time = 0 + await pilot.press("1") + text_area.history.mock_time = 1.0 + await pilot.press("2") -async def test_undo_checkpoint_character_limit_reached( - pilot: Pilot, text_area: TextArea -): - await pilot.press("1") - # Since the insertion below is > 100 characters it goes to a new batch. - text_area.insert("2" * 120) + # Since "3" appears 10 seconds later, it's in a separate batch. + text_area.history.mock_time += 10.0 + await pilot.press("3") - text_area.undo() - assert text_area.text == "1" - text_area.undo() - assert text_area.text == "" + assert text_area.text == "123" + text_area.undo() + assert text_area.text == "12" -async def test_redo_with_no_undo_is_noop(text_area: TextArea): - text_area.text = SIMPLE_TEXT - text_area.redo() - assert text_area.text == SIMPLE_TEXT + text_area.undo() + assert text_area.text == "" -async def test_undo_with_empty_undo_stack_is_noop(text_area: TextArea): - text_area.text = SIMPLE_TEXT - text_area.undo() - assert text_area.text == SIMPLE_TEXT +async def test_undo_checkpoint_character_limit_reached(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + await pilot.press("1") + # Since the insertion below is > 100 characters it goes to a new batch. + text_area.insert("2" * 120) + text_area.undo() + assert text_area.text == "1" + text_area.undo() + assert text_area.text == "" -async def test_redo_stack_cleared_on_edit(pilot: Pilot, text_area: TextArea): - text_area.text = "" - await pilot.press("1") - text_area.history.checkpoint() - await pilot.press("2") - text_area.history.checkpoint() - await pilot.press("3") - text_area.undo() - text_area.undo() - text_area.undo() - assert text_area.text == "" - assert text_area.selection == Selection.cursor((0, 0)) +async def test_redo_with_no_undo_is_noop(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = SIMPLE_TEXT + text_area.redo() + assert text_area.text == SIMPLE_TEXT - # Redo stack has 3 edits in it now. - await pilot.press("f") - assert text_area.text == "f" - assert text_area.selection == Selection.cursor((0, 1)) - # Redo stack is cleared because of the edit, so redo has no effect. - text_area.redo() - assert text_area.text == "f" - assert text_area.selection == Selection.cursor((0, 1)) - text_area.redo() - assert text_area.text == "f" - assert text_area.selection == Selection.cursor((0, 1)) +async def test_undo_with_empty_undo_stack_is_noop(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = SIMPLE_TEXT + text_area.undo() + assert text_area.text == SIMPLE_TEXT -async def test_inserts_not_batched_with_deletes(pilot: Pilot, text_area: TextArea): +async def test_redo_stack_cleared_on_edit(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = "" + await pilot.press("1") + text_area.history.checkpoint() + await pilot.press("2") + text_area.history.checkpoint() + await pilot.press("3") + + text_area.undo() + text_area.undo() + text_area.undo() + assert text_area.text == "" + assert text_area.selection == Selection.cursor((0, 0)) + + # Redo stack has 3 edits in it now. + await pilot.press("f") + assert text_area.text == "f" + assert text_area.selection == Selection.cursor((0, 1)) + + # Redo stack is cleared because of the edit, so redo has no effect. + text_area.redo() + assert text_area.text == "f" + assert text_area.selection == Selection.cursor((0, 1)) + text_area.redo() + assert text_area.text == "f" + assert text_area.selection == Selection.cursor((0, 1)) + + +async def test_inserts_not_batched_with_deletes(): # 3 batches here: __1___ ___________2____________ __3__ - await pilot.press(*"123", "backspace", "backspace", *"23") - - assert text_area.text == "123" - - # Undo batch 1: the "23" insertion. - text_area.undo() - assert text_area.text == "1" - # Undo batch 2: the double backspace. - text_area.undo() - assert text_area.text == "123" - - # Undo batch 3: the "123" insertion. - text_area.undo() - assert text_area.text == "" + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + await pilot.press(*"123", "backspace", "backspace", *"23") -async def test_paste_is_an_isolated_batch(pilot: Pilot, text_area: TextArea): - pilot.app.post_message(Paste("hello ")) - pilot.app.post_message(Paste("world")) - await pilot.pause() + assert text_area.text == "123" - assert text_area.text == "hello world" + # Undo batch 1: the "23" insertion. + text_area.undo() + assert text_area.text == "1" - await pilot.press("!") + # Undo batch 2: the double backspace. + text_area.undo() + assert text_area.text == "123" - # The insertion of "!" does not get batched with the paste of "world". - text_area.undo() - assert text_area.text == "hello world" + # Undo batch 3: the "123" insertion. + text_area.undo() + assert text_area.text == "" - text_area.undo() - assert text_area.text == "hello " - text_area.undo() - assert text_area.text == "" +async def test_paste_is_an_isolated_batch(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + pilot.app.post_message(Paste("hello ")) + pilot.app.post_message(Paste("world")) + await pilot.pause() + assert text_area.text == "hello world" -async def test_focus_creates_checkpoint(pilot: Pilot, text_area: TextArea): - await pilot.press(*"123") - text_area.has_focus = False - text_area.has_focus = True - await pilot.press(*"456") - assert text_area.text == "123456" + await pilot.press("!") - # Since we re-focused, a checkpoint exists between 123 and 456, - # so when we use undo, only the 456 is removed. - text_area.undo() - assert text_area.text == "123" + # The insertion of "!" does not get batched with the paste of "world". + text_area.undo() + assert text_area.text == "hello world" + text_area.undo() + assert text_area.text == "hello " -async def test_undo_redo_deletions_batched(pilot: Pilot, text_area: TextArea): - text_area.text = SIMPLE_TEXT - text_area.selection = Selection((0, 2), (1, 2)) + text_area.undo() + assert text_area.text == "" - # Perform a single delete of some selected text. It'll live in it's own - # batch since it's a multi-line operation. - await pilot.press("backspace") - checkpoint_one = "ABHIJ\nKLMNO\nPQRST\nUVWXY\nZ\n" - assert text_area.text == checkpoint_one - assert text_area.selection == Selection.cursor((0, 2)) - # Pressing backspace a few times to delete more characters. - await pilot.press("backspace", "backspace", "backspace") - checkpoint_two = "HIJ\nKLMNO\nPQRST\nUVWXY\nZ\n" - assert text_area.text == checkpoint_two - assert text_area.selection == Selection.cursor((0, 0)) +async def test_focus_creates_checkpoint(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + await pilot.press(*"123") + text_area.has_focus = False + text_area.has_focus = True + await pilot.press(*"456") + assert text_area.text == "123456" - # When we undo, the 3 deletions above should be batched, but not - # the original deletion since it contains a newline character. - text_area.undo() - assert text_area.text == checkpoint_one - assert text_area.selection == Selection.cursor((0, 2)) + # Since we re-focused, a checkpoint exists between 123 and 456, + # so when we use undo, only the 456 is removed. + text_area.undo() + assert text_area.text == "123" - # Undoing again restores us back to our initial text and selection. - text_area.undo() - assert text_area.text == SIMPLE_TEXT - assert text_area.selection == Selection((0, 2), (1, 2)) - # At this point, the undo stack contains two items, so we can redo twice. +async def test_undo_redo_deletions_batched(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + text_area.text = SIMPLE_TEXT + text_area.selection = Selection((0, 2), (1, 2)) + + # Perform a single delete of some selected text. It'll live in it's own + # batch since it's a multi-line operation. + await pilot.press("backspace") + checkpoint_one = "ABHIJ\nKLMNO\nPQRST\nUVWXY\nZ\n" + assert text_area.text == checkpoint_one + assert text_area.selection == Selection.cursor((0, 2)) + + # Pressing backspace a few times to delete more characters. + await pilot.press("backspace", "backspace", "backspace") + checkpoint_two = "HIJ\nKLMNO\nPQRST\nUVWXY\nZ\n" + assert text_area.text == checkpoint_two + assert text_area.selection == Selection.cursor((0, 0)) + + # When we undo, the 3 deletions above should be batched, but not + # the original deletion since it contains a newline character. + text_area.undo() + assert text_area.text == checkpoint_one + assert text_area.selection == Selection.cursor((0, 2)) + + # Undoing again restores us back to our initial text and selection. + text_area.undo() + assert text_area.text == SIMPLE_TEXT + assert text_area.selection == Selection((0, 2), (1, 2)) + + # At this point, the undo stack contains two items, so we can redo twice. + + # Redo to go back to checkpoint one. + text_area.redo() + assert text_area.text == checkpoint_one + assert text_area.selection == Selection.cursor((0, 2)) + + # Redo again to go back to checkpoint two + text_area.redo() + assert text_area.text == checkpoint_two + assert text_area.selection == Selection.cursor((0, 0)) + + # Redo again does nothing. + text_area.redo() + assert text_area.text == checkpoint_two + assert text_area.selection == Selection.cursor((0, 0)) + + +async def test_max_checkpoints(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + assert len(text_area.history.undo_stack) == 0 + for index in range(MAX_CHECKPOINTS): + # Press enter since that will ensure a checkpoint is created. + await pilot.press("enter") - # Redo to go back to checkpoint one. - text_area.redo() - assert text_area.text == checkpoint_one - assert text_area.selection == Selection.cursor((0, 2)) + assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS + await pilot.press("enter") + # Ensure we don't go over the limit. + assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS - # Redo again to go back to checkpoint two - text_area.redo() - assert text_area.text == checkpoint_two - assert text_area.selection == Selection.cursor((0, 0)) - # Redo again does nothing. - text_area.redo() - assert text_area.text == checkpoint_two - assert text_area.selection == Selection.cursor((0, 0)) +async def test_redo_stack(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + assert len(text_area.history.redo_stack) == 0 + await pilot.press("enter") + await pilot.press(*"123") + assert len(text_area.history.undo_stack) == 2 + assert len(text_area.history.redo_stack) == 0 + text_area.undo() + assert len(text_area.history.undo_stack) == 1 + assert len(text_area.history.redo_stack) == 1 + text_area.undo() + assert len(text_area.history.undo_stack) == 0 + assert len(text_area.history.redo_stack) == 2 + text_area.redo() + assert len(text_area.history.undo_stack) == 1 + assert len(text_area.history.redo_stack) == 1 + text_area.redo() + assert len(text_area.history.undo_stack) == 2 + assert len(text_area.history.redo_stack) == 0 + + +async def test_backward_selection_undo_redo(): + app = TextAreaApp() + async with app.run_test() as pilot: + text_area = app.text_area + # Failed prior to https://github.com/Textualize/textual/pull/4352 + text_area.text = SIMPLE_TEXT + text_area.selection = Selection((3, 2), (0, 0)) + await pilot.press("a") -async def test_max_checkpoints(pilot: Pilot, text_area: TextArea): - assert len(text_area.history.undo_stack) == 0 - for index in range(MAX_CHECKPOINTS): - # Press enter since that will ensure a checkpoint is created. - await pilot.press("enter") + text_area.undo() + await pilot.press("down", "down", "down", "down") - assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS - await pilot.press("enter") - # Ensure we don't go over the limit. - assert len(text_area.history.undo_stack) == MAX_CHECKPOINTS - - -async def test_redo_stack(pilot: Pilot, text_area: TextArea): - assert len(text_area.history.redo_stack) == 0 - await pilot.press("enter") - await pilot.press(*"123") - assert len(text_area.history.undo_stack) == 2 - assert len(text_area.history.redo_stack) == 0 - text_area.undo() - assert len(text_area.history.undo_stack) == 1 - assert len(text_area.history.redo_stack) == 1 - text_area.undo() - assert len(text_area.history.undo_stack) == 0 - assert len(text_area.history.redo_stack) == 2 - text_area.redo() - assert len(text_area.history.undo_stack) == 1 - assert len(text_area.history.redo_stack) == 1 - text_area.redo() - assert len(text_area.history.undo_stack) == 2 - assert len(text_area.history.redo_stack) == 0 - - -async def test_backward_selection_undo_redo(pilot: Pilot, text_area: TextArea): - # Failed prior to https://github.com/Textualize/textual/pull/4352 - text_area.text = SIMPLE_TEXT - text_area.selection = Selection((3, 2), (0, 0)) - - await pilot.press("a") - - text_area.undo() - await pilot.press("down", "down", "down", "down") - - assert text_area.text == SIMPLE_TEXT + assert text_area.text == SIMPLE_TEXT