diff --git a/CHANGELOG.md b/CHANGELOG.md index 50dc3f19e5..93653a5532 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,10 +7,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Changed + +- Breaking change: `DOMNode.has_pseudo_class` now accepts a single name only https://github.com/Textualize/textual/pull/3970 + +### Added + +- Added `DOMNode.has_pseudo_classes` https://github.com/Textualize/textual/pull/3970 + ### Fixed - Parameter `animate` from `DataTable.move_cursor` was being ignored https://github.com/Textualize/textual/issues/3840 + ## [0.47.1] - 2023-01-05 ### Fixed diff --git a/src/textual/css/model.py b/src/textual/css/model.py index d31217324a..b2bf25f9a8 100644 --- a/src/textual/css/model.py +++ b/src/textual/css/model.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from enum import Enum +from functools import partial from typing import TYPE_CHECKING, Iterable import rich.repr @@ -12,6 +13,8 @@ from .types import Specificity3 if TYPE_CHECKING: + from typing import Callable + from typing_extensions import Self from ..dom import DOMNode @@ -43,6 +46,67 @@ class CombinatorType(Enum): """Selector is an immediate child of the previous selector""" +def _check_universal(name: str, node: DOMNode) -> bool: + """Check node matches universal selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return True + + +def _check_type(name: str, node: DOMNode) -> bool: + """Check node matches a type selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return name in node._css_type_names + + +def _check_class(name: str, node: DOMNode) -> bool: + """Check node matches a class selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return name in node._classes + + +def _check_id(name: str, node: DOMNode) -> bool: + """Check node matches an ID selector. + + Args: + name: Selector name. + node: A DOM node. + + Returns: + `True` if the selector matches. + """ + return node.id == name + + +_CHECKS = { + SelectorType.UNIVERSAL: _check_universal, + SelectorType.TYPE: _check_type, + SelectorType.CLASS: _check_class, + SelectorType.ID: _check_id, + SelectorType.NESTED: _check_universal, +} + + @dataclass class Selector: """Represents a CSS selector. @@ -57,14 +121,17 @@ class Selector: name: str combinator: CombinatorType = CombinatorType.DESCENDENT type: SelectorType = SelectorType.TYPE - pseudo_classes: list[str] = field(default_factory=list) + pseudo_classes: set[str] = field(default_factory=set) specificity: Specificity3 = field(default_factory=lambda: (0, 0, 0)) advance: int = 1 + def __post_init__(self) -> None: + self._check: Callable[[DOMNode], bool] = partial(_CHECKS[self.type], self.name) + @property def css(self) -> str: """Rebuilds the selector as it would appear in CSS.""" - pseudo_suffix = "".join(f":{name}" for name in self.pseudo_classes) + pseudo_suffix = "".join(f":{name}" for name in sorted(self.pseudo_classes)) if self.type == SelectorType.UNIVERSAL: return "*" elif self.type == SelectorType.TYPE: @@ -74,21 +141,13 @@ def css(self) -> str: else: return f"#{self.name}{pseudo_suffix}" - def __post_init__(self) -> None: - self._checks = { - SelectorType.UNIVERSAL: self._check_universal, - SelectorType.TYPE: self._check_type, - SelectorType.CLASS: self._check_class, - SelectorType.ID: self._check_id, - } - def _add_pseudo_class(self, pseudo_class: str) -> None: """Adds a pseudo class and updates specificity. Args: pseudo_class: Name of pseudo class. """ - self.pseudo_classes.append(pseudo_class) + self.pseudo_classes.add(pseudo_class) specificity1, specificity2, specificity3 = self.specificity self.specificity = (specificity1, specificity2 + 1, specificity3) @@ -101,31 +160,11 @@ def check(self, node: DOMNode) -> bool: Returns: True if the selector matches, otherwise False. """ - return self._checks[self.type](node) - - def _check_universal(self, node: DOMNode) -> bool: - return node.has_pseudo_class(*self.pseudo_classes) - - def _check_type(self, node: DOMNode) -> bool: - if self.name not in node._css_type_names: - return False - if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes): - return False - return True - - def _check_class(self, node: DOMNode) -> bool: - if not node.has_class(self.name): - return False - if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes): - return False - return True - - def _check_id(self, node: DOMNode) -> bool: - if node.id != self.name: - return False - if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes): - return False - return True + return self._check(node) and ( + node.has_pseudo_classes(self.pseudo_classes) + if self.pseudo_classes + else True + ) @dataclass diff --git a/src/textual/css/parse.py b/src/textual/css/parse.py index d41e4b2935..b9e59d94ba 100644 --- a/src/textual/css/parse.py +++ b/src/textual/css/parse.py @@ -208,11 +208,9 @@ def combine_selectors( nested_selector = selectors2[0] merged_selector = dataclasses.replace( final_selector, - pseudo_classes=list( - set( - final_selector.pseudo_classes - + nested_selector.pseudo_classes - ) + pseudo_classes=( + final_selector.pseudo_classes + | nested_selector.pseudo_classes ), specificity=_add_specificity( final_selector.specificity, nested_selector.specificity diff --git a/src/textual/css/stylesheet.py b/src/textual/css/stylesheet.py index 587fa166c8..e56ed6a873 100644 --- a/src/textual/css/stylesheet.py +++ b/src/textual/css/stylesheet.py @@ -30,6 +30,8 @@ class StylesheetParseError(StylesheetError): + """Raised when the stylesheet could not be parsed.""" + def __init__(self, errors: StylesheetErrors) -> None: self.errors = errors @@ -38,6 +40,8 @@ def __rich__(self) -> RenderableType: class StylesheetErrors: + """A renderable for stylesheet errors.""" + def __init__(self, rules: list[RuleSet]) -> None: self.rules = rules self.variables: dict[str, str] = {} @@ -134,6 +138,8 @@ class CssSource(NamedTuple): @rich.repr.auto(angular=True) class Stylesheet: + """A Stylsheet generated from Textual CSS.""" + def __init__(self, *, variables: dict[str, str] | None = None) -> None: self._rules: list[RuleSet] = [] self._rules_map: dict[str, list[RuleSet]] | None = None @@ -183,6 +189,10 @@ def rules_map(self) -> dict[str, list[RuleSet]]: @property def css(self) -> str: + """The equivalent TCSS for this stylesheet. + + Note that this may not produce the same content as the file(s) used to generate the stylesheet. + """ return "\n\n".join(rule_set.css for rule_set in self.rules) def copy(self) -> Stylesheet: @@ -407,9 +417,18 @@ def reparse(self) -> None: @classmethod def _check_rule( - cls, rule: RuleSet, css_path_nodes: list[DOMNode] + cls, rule_set: RuleSet, css_path_nodes: list[DOMNode] ) -> Iterable[Specificity3]: - for selector_set in rule.selector_set: + """Check a rule set, return specificity of applicable rules. + + Args: + rule_set: A rule set. + css_path_nodes: A list of the nodes from the App to the node being checked. + + Yields: + Specificity of any matching selectors. + """ + for selector_set in rule_set.selector_set: if _check_selectors(selector_set.selectors, css_path_nodes): yield selector_set.specificity diff --git a/src/textual/dom.py b/src/textual/dom.py index a9c832788e..461e9acea7 100644 --- a/src/textual/dom.py +++ b/src/textual/dom.py @@ -166,7 +166,7 @@ def __init__( id: str | None = None, classes: str | None = None, ) -> None: - self._classes = set() + self._classes: set[str] = set() self._name = name self._id = None if id is not None: @@ -575,8 +575,7 @@ def css_identifier_styled(self) -> Text: @property def pseudo_classes(self) -> frozenset[str]: """A (frozen) set of all pseudo classes.""" - pseudo_classes = frozenset(self.get_pseudo_classes()) - return pseudo_classes + return frozenset(self.get_pseudo_classes()) @property def css_path_nodes(self) -> list[DOMNode]: @@ -1276,17 +1275,27 @@ def toggle_class(self, *class_names: str) -> Self: self._update_styles() return self - def has_pseudo_class(self, *class_names: str) -> bool: - """Check for pseudo classes (such as hover, focus etc) + def has_pseudo_class(self, class_name: str) -> bool: + """Check the node has the given pseudo class. Args: - *class_names: The pseudo classes to check for. + class_name: The pseudo class to check for. Returns: - `True` if the DOM node has those pseudo classes, `False` if not. + `True` if the DOM node has the pseudo class, `False` if not. """ - has_pseudo_classes = self.pseudo_classes.issuperset(class_names) - return has_pseudo_classes + return class_name in self.get_pseudo_classes() + + def has_pseudo_classes(self, class_names: set[str]) -> bool: + """Check the node has all the given pseudo classes. + + Args: + class_names: Set of class names to check for. + + Returns: + `True` if all pseudo class names are present. + """ + return class_names.issubset(self.get_pseudo_classes()) def refresh(self, *, repaint: bool = True, layout: bool = False) -> Self: return self diff --git a/src/textual/widget.py b/src/textual/widget.py index d0c549a5ae..ffe458d524 100644 --- a/src/textual/widget.py +++ b/src/textual/widget.py @@ -2876,14 +2876,6 @@ def get_pseudo_classes(self) -> Iterable[str]: Returns: Names of the pseudo classes. """ - node: MessagePump | None = self - while isinstance(node, Widget): - if node.disabled: - yield "disabled" - break - node = node._parent - else: - yield "enabled" if self.mouse_over: yield "hover" if self.has_focus: @@ -2892,6 +2884,14 @@ def get_pseudo_classes(self) -> Iterable[str]: yield "blur" if self.can_focus: yield "can-focus" + node: MessagePump | None = self + while isinstance(node, Widget): + if node.disabled: + yield "disabled" + break + node = node._parent + else: + yield "enabled" try: focused = self.screen.focused except NoScreen: