diff --git a/CHANGELOG.md b/CHANGELOG.md index 158f4a4b39..1fe716f5f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,10 +11,13 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added `DOMNode.check_consume_key` https://github.com/Textualize/textual/pull/4940 - Added `App.ESCAPE_TO_MINIMIZE`, `App.screen_to_minimize`, and `Screen.ESCAPE_TO_MINIMIZE` https://github.com/Textualize/textual/pull/4951 +- Added `DOMNode.query_exactly_one` https://github.com/Textualize/textual/pull/4950 +- Added `SelectorSet.is_simple` https://github.com/Textualize/textual/pull/4950 ### Changed - KeyPanel will show multiple keys if bound to the same action https://github.com/Textualize/textual/pull/4940 +- Breaking change: `DOMNode.query_one` will not `raise TooManyMatches` https://github.com/Textualize/textual/pull/4950 ## [0.78.0] - 2024-08-27 diff --git a/docs/guide/queries.md b/docs/guide/queries.md index d33659f382..c0ce0be51f 100644 --- a/docs/guide/queries.md +++ b/docs/guide/queries.md @@ -21,7 +21,6 @@ send_button = self.query_one("#send") This will retrieve a widget with an ID of `send`, if there is exactly one. If there are no matching widgets, Textual will raise a [NoMatches][textual.css.query.NoMatches] exception. -If there is more than one match, Textual will raise a [TooManyMatches][textual.css.query.TooManyMatches] exception. You can also add a second parameter for the expected type, which will ensure that you get the type you are expecting. diff --git a/src/textual/_node_list.py b/src/textual/_node_list.py index 198558777d..52555f9d61 100644 --- a/src/textual/_node_list.py +++ b/src/textual/_node_list.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from _typeshed import SupportsRichComparison + from .dom import DOMNode from .widget import Widget @@ -24,7 +25,8 @@ class NodeList(Sequence["Widget"]): Although named a list, widgets may appear only once, making them more like a set. """ - def __init__(self) -> None: + def __init__(self, parent: DOMNode | None = None) -> None: + self._parent = parent # The nodes in the list self._nodes: list[Widget] = [] self._nodes_set: set[Widget] = set() @@ -52,6 +54,13 @@ def __len__(self) -> int: def __contains__(self, widget: object) -> bool: return widget in self._nodes + def updated(self) -> None: + """Mark the nodes as having been updated.""" + self._updates += 1 + node = self._parent + while node is not None and (node := node._parent) is not None: + node._nodes._updates += 1 + def _sort( self, *, @@ -69,7 +78,7 @@ def _sort( else: self._nodes.sort(key=key, reverse=reverse) - self._updates += 1 + self.updated() def index(self, widget: Any, start: int = 0, stop: int = sys.maxsize) -> int: """Return the index of the given widget. @@ -102,7 +111,7 @@ def _append(self, widget: Widget) -> None: if widget_id is not None: self._ensure_unique_id(widget_id) self._nodes_by_id[widget_id] = widget - self._updates += 1 + self.updated() def _insert(self, index: int, widget: Widget) -> None: """Insert a Widget. @@ -117,7 +126,7 @@ def _insert(self, index: int, widget: Widget) -> None: if widget_id is not None: self._ensure_unique_id(widget_id) self._nodes_by_id[widget_id] = widget - self._updates += 1 + self.updated() def _ensure_unique_id(self, widget_id: str) -> None: if widget_id in self._nodes_by_id: @@ -141,7 +150,7 @@ def _remove(self, widget: Widget) -> None: widget_id = widget.id if widget_id in self._nodes_by_id: del self._nodes_by_id[widget_id] - self._updates += 1 + self.updated() def _clear(self) -> None: """Clear the node list.""" @@ -149,7 +158,7 @@ def _clear(self) -> None: self._nodes.clear() self._nodes_set.clear() self._nodes_by_id.clear() - self._updates += 1 + self.updated() def __iter__(self) -> Iterator[Widget]: return iter(self._nodes) diff --git a/src/textual/css/model.py b/src/textual/css/model.py index b2bf25f9a8..cf5f55b83b 100644 --- a/src/textual/css/model.py +++ b/src/textual/css/model.py @@ -193,6 +193,15 @@ def __post_init__(self) -> None: def css(self) -> str: return RuleSet._selector_to_css(self.selectors) + @property + def is_simple(self) -> bool: + """Are all the selectors simple (i.e. only dependent on static DOM state).""" + simple_types = {SelectorType.ID, SelectorType.TYPE} + return all( + (selector.type in simple_types and not selector.pseudo_classes) + for selector in self.selectors + ) + def __rich_repr__(self) -> rich.repr.Result: selectors = RuleSet._selector_to_css(self.selectors) yield selectors diff --git a/src/textual/dom.py b/src/textual/dom.py index b75bbda542..0fc190e55a 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -33,11 +33,14 @@ from ._node_list import NodeList from ._types import WatchCallbackType from .binding import Binding, BindingsMap, BindingType +from .cache import LRUCache from .color import BLACK, WHITE, Color from .css._error_tools import friendly_list from .css.constants import VALID_DISPLAY, VALID_VISIBILITY from .css.errors import DeclarationError, StyleValueError -from .css.parse import parse_declarations +from .css.match import match +from .css.parse import parse_declarations, parse_selectors +from .css.query import NoMatches, TooManyMatches from .css.styles import RenderStyles, Styles from .css.tokenize import IDENTIFIER from .message_pump import MessagePump @@ -60,7 +63,7 @@ from .worker import Worker, WorkType, ResultType # Unused & ignored imports are needed for the docs to link to these objects: - from .css.query import NoMatches, TooManyMatches, WrongType # type: ignore # noqa: F401 + from .css.query import WrongType # type: ignore # noqa: F401 from typing_extensions import Literal @@ -74,6 +77,10 @@ ReactiveType = TypeVar("ReactiveType") +QueryOneCacheKey: TypeAlias = "tuple[int, str, Type[Widget] | None]" +"""The key used to cache query_one results.""" + + class BadIdentifier(Exception): """Exception raised if you supply a `id` attribute or class name in the wrong format.""" @@ -184,13 +191,14 @@ def __init__( self._name = name self._id = None if id is not None: - self.id = id + check_identifiers("id", id) + self._id = id _classes = classes.split() if classes else [] check_identifiers("class name", *_classes) self._classes.update(_classes) - self._nodes: NodeList = NodeList() + self._nodes: NodeList = NodeList(self) self._css_styles: Styles = Styles(self) self._inline_styles: Styles = Styles(self) self.styles: RenderStyles = RenderStyles( @@ -213,6 +221,8 @@ def __init__( dict[str, tuple[MessagePump, Reactive | object]] | None ) = None self._pruning = False + self._query_one_cache: LRUCache[QueryOneCacheKey, DOMNode] = LRUCache(1024) + super().__init__() def set_reactive( @@ -741,7 +751,7 @@ def id(self, new_id: str) -> str: ValueError: If the ID has already been set. """ check_identifiers("id", new_id) - + self._nodes.updated() if self._id is not None: raise ValueError( f"Node 'id' attribute may not be changed once set (current id={self._id!r})" @@ -1393,21 +1403,110 @@ def query_one( Raises: WrongType: If the wrong type was found. NoMatches: If no node matches the query. - TooManyMatches: If there is more than one matching node in the query. Returns: A widget matching the selector. """ _rich_traceback_omit = True - from .css.query import DOMQuery if isinstance(selector, str): query_selector = selector else: query_selector = selector.__name__ - query: DOMQuery[Widget] = DOMQuery(self, filter=query_selector) - return query.only_one() if expect_type is None else query.only_one(expect_type) + selector_set = parse_selectors(query_selector) + + if all(selectors.is_simple for selectors in selector_set): + cache_key = (self._nodes._updates, query_selector, expect_type) + cached_result = self._query_one_cache.get(cache_key) + if cached_result is not None: + return cached_result + else: + cache_key = None + + for node in walk_depth_first(self, with_root=False): + if not match(selector_set, node): + continue + if expect_type is not None and not isinstance(node, expect_type): + continue + if cache_key is not None: + self._query_one_cache[cache_key] = node + return node + + raise NoMatches(f"No nodes match {selector!r} on {self!r}") + + if TYPE_CHECKING: + + @overload + def query_exactly_one(self, selector: str) -> Widget: ... + + @overload + def query_exactly_one(self, selector: type[QueryType]) -> QueryType: ... + + @overload + def query_exactly_one( + self, selector: str, expect_type: type[QueryType] + ) -> QueryType: ... + + def query_exactly_one( + self, + selector: str | type[QueryType], + expect_type: type[QueryType] | None = None, + ) -> QueryType | Widget: + """Get a widget from this widget's children that matches a selector or widget type. + + !!! Note + This method is similar to [query_one][textual.dom.DOMNode.query_one]. + The only difference is that it will raise `TooManyMatches` if there is more than a single match. + + Args: + selector: A selector or widget type. + expect_type: Require the object be of the supplied type, or None for any type. + + Raises: + WrongType: If the wrong type was found. + NoMatches: If no node matches the query. + TooManyMatches: If there is more than one matching node in the query (and `exactly_one==True`). + + Returns: + A widget matching the selector. + """ + _rich_traceback_omit = True + + if isinstance(selector, str): + query_selector = selector + else: + query_selector = selector.__name__ + + selector_set = parse_selectors(query_selector) + + if all(selectors.is_simple for selectors in selector_set): + cache_key = (self._nodes._updates, query_selector, expect_type) + cached_result = self._query_one_cache.get(cache_key) + if cached_result is not None: + return cached_result + else: + cache_key = None + + children = walk_depth_first(self, with_root=False) + iter_children = iter(children) + for node in iter_children: + if not match(selector_set, node): + continue + if expect_type is not None and not isinstance(node, expect_type): + continue + for later_node in iter_children: + if match(selector_set, later_node): + if expect_type is not None and not isinstance(node, expect_type): + continue + raise TooManyMatches( + "Call to query_one resulted in more than one matched node" + ) + if cache_key is not None: + self._query_one_cache[cache_key] = node + return node + + raise NoMatches(f"No nodes match {selector!r} on {self!r}") def set_styles(self, css: str | None = None, **update_styles: Any) -> Self: """Set custom styles on this object. diff --git a/src/textual/widget.py b/src/textual/widget.py index 9296d20865..b71cc971ae 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -87,7 +87,6 @@ from .renderables.blank import Blank from .rlock import RLock from .strip import Strip -from .walk import walk_depth_first if TYPE_CHECKING: from .app import App, ComposeResult @@ -807,21 +806,14 @@ def get_widget_by_id( NoMatches: if no children could be found for this ID. WrongType: if the wrong type was found. """ - # We use Widget as a filter_type so that the inferred type of child is Widget. - for child in walk_depth_first(self, filter_type=Widget): - try: - if expect_type is None: - return child.get_child_by_id(id) - else: - return child.get_child_by_id(id, expect_type=expect_type) - except NoMatches: - pass - except WrongType as exc: - raise WrongType( - f"Descendant with id={id!r} is wrong type; expected {expect_type}," - f" got {type(child)}" - ) from exc - raise NoMatches(f"No descendant found with id={id!r}") + + widget = self.query_one(f"#{id}") + if expect_type is not None and not isinstance(widget, expect_type): + raise WrongType( + f"Descendant with id={id!r} is wrong type; expected {expect_type}," + f" got {type(widget)}" + ) + return widget def get_child_by_type(self, expect_type: type[ExpectType]) -> ExpectType: """Get the first immediate child of a given type. @@ -958,7 +950,7 @@ def _find_mount_point(self, spot: int | str | "Widget") -> tuple["Widget", int]: # can be passed to query_one. So let's use that to get a widget to # work on. if isinstance(spot, str): - spot = self.query_one(spot, Widget) + spot = self.query_exactly_one(spot, Widget) # At this point we should have a widget, either because we got given # one, or because we pulled one out of the query. First off, does it diff --git a/tests/test_query.py b/tests/test_query.py index 07f608824a..4030003866 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -103,7 +103,7 @@ class App(Widget): assert app.query_one("#widget1") == widget1 assert app.query_one("#widget1", Widget) == widget1 with pytest.raises(TooManyMatches): - _ = app.query_one(Widget) + _ = app.query_exactly_one(Widget) assert app.query("Widget.float")[0] == sidebar assert app.query("Widget.float")[0:2] == [sidebar, tooltip]