Skip to content

Commit

Permalink
optimization (#3970)
Browse files Browse the repository at this point in the history
* optimization

* fix

* changelog [skip ci]

* sort

* simplification

* simplify check

* fix and typing

* typing

* docstrings

* Apply suggestions from code review

Co-authored-by: Rodrigo Girão Serrão <[email protected]>

* Update src/textual/dom.py

Co-authored-by: Rodrigo Girão Serrão <[email protected]>

---------

Co-authored-by: Rodrigo Girão Serrão <[email protected]>
  • Loading branch information
willmcgugan and rodrigogiraoserrao authored Jan 8, 2024
1 parent 5cd5aa9 commit 8f822ae
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 60 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Changed

- Breaking change: `DOMNode.has_pseudo_class` now accepts a single name only https://github.com/Textualize/textual/pull/3970

### Added

- Added `DOMNode.has_pseudo_classes` https://github.com/Textualize/textual/pull/3970

### Fixed

- Parameter `animate` from `DataTable.move_cursor` was being ignored https://github.com/Textualize/textual/issues/3840


## [0.47.1] - 2023-01-05

### Fixed
Expand Down
111 changes: 75 additions & 36 deletions src/textual/css/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import TYPE_CHECKING, Iterable

import rich.repr
Expand All @@ -12,6 +13,8 @@
from .types import Specificity3

if TYPE_CHECKING:
from typing import Callable

from typing_extensions import Self

from ..dom import DOMNode
Expand Down Expand Up @@ -43,6 +46,67 @@ class CombinatorType(Enum):
"""Selector is an immediate child of the previous selector"""


def _check_universal(name: str, node: DOMNode) -> bool:
"""Check node matches universal selector.
Args:
name: Selector name.
node: A DOM node.
Returns:
`True` if the selector matches.
"""
return True


def _check_type(name: str, node: DOMNode) -> bool:
"""Check node matches a type selector.
Args:
name: Selector name.
node: A DOM node.
Returns:
`True` if the selector matches.
"""
return name in node._css_type_names


def _check_class(name: str, node: DOMNode) -> bool:
"""Check node matches a class selector.
Args:
name: Selector name.
node: A DOM node.
Returns:
`True` if the selector matches.
"""
return name in node._classes


def _check_id(name: str, node: DOMNode) -> bool:
"""Check node matches an ID selector.
Args:
name: Selector name.
node: A DOM node.
Returns:
`True` if the selector matches.
"""
return node.id == name


_CHECKS = {
SelectorType.UNIVERSAL: _check_universal,
SelectorType.TYPE: _check_type,
SelectorType.CLASS: _check_class,
SelectorType.ID: _check_id,
SelectorType.NESTED: _check_universal,
}


@dataclass
class Selector:
"""Represents a CSS selector.
Expand All @@ -57,14 +121,17 @@ class Selector:
name: str
combinator: CombinatorType = CombinatorType.DESCENDENT
type: SelectorType = SelectorType.TYPE
pseudo_classes: list[str] = field(default_factory=list)
pseudo_classes: set[str] = field(default_factory=set)
specificity: Specificity3 = field(default_factory=lambda: (0, 0, 0))
advance: int = 1

def __post_init__(self) -> None:
self._check: Callable[[DOMNode], bool] = partial(_CHECKS[self.type], self.name)

@property
def css(self) -> str:
"""Rebuilds the selector as it would appear in CSS."""
pseudo_suffix = "".join(f":{name}" for name in self.pseudo_classes)
pseudo_suffix = "".join(f":{name}" for name in sorted(self.pseudo_classes))
if self.type == SelectorType.UNIVERSAL:
return "*"
elif self.type == SelectorType.TYPE:
Expand All @@ -74,21 +141,13 @@ def css(self) -> str:
else:
return f"#{self.name}{pseudo_suffix}"

def __post_init__(self) -> None:
self._checks = {
SelectorType.UNIVERSAL: self._check_universal,
SelectorType.TYPE: self._check_type,
SelectorType.CLASS: self._check_class,
SelectorType.ID: self._check_id,
}

def _add_pseudo_class(self, pseudo_class: str) -> None:
"""Adds a pseudo class and updates specificity.
Args:
pseudo_class: Name of pseudo class.
"""
self.pseudo_classes.append(pseudo_class)
self.pseudo_classes.add(pseudo_class)
specificity1, specificity2, specificity3 = self.specificity
self.specificity = (specificity1, specificity2 + 1, specificity3)

Expand All @@ -101,31 +160,11 @@ def check(self, node: DOMNode) -> bool:
Returns:
True if the selector matches, otherwise False.
"""
return self._checks[self.type](node)

def _check_universal(self, node: DOMNode) -> bool:
return node.has_pseudo_class(*self.pseudo_classes)

def _check_type(self, node: DOMNode) -> bool:
if self.name not in node._css_type_names:
return False
if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes):
return False
return True

def _check_class(self, node: DOMNode) -> bool:
if not node.has_class(self.name):
return False
if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes):
return False
return True

def _check_id(self, node: DOMNode) -> bool:
if node.id != self.name:
return False
if self.pseudo_classes and not node.has_pseudo_class(*self.pseudo_classes):
return False
return True
return self._check(node) and (
node.has_pseudo_classes(self.pseudo_classes)
if self.pseudo_classes
else True
)


@dataclass
Expand Down
8 changes: 3 additions & 5 deletions src/textual/css/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,9 @@ def combine_selectors(
nested_selector = selectors2[0]
merged_selector = dataclasses.replace(
final_selector,
pseudo_classes=list(
set(
final_selector.pseudo_classes
+ nested_selector.pseudo_classes
)
pseudo_classes=(
final_selector.pseudo_classes
| nested_selector.pseudo_classes
),
specificity=_add_specificity(
final_selector.specificity, nested_selector.specificity
Expand Down
23 changes: 21 additions & 2 deletions src/textual/css/stylesheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@


class StylesheetParseError(StylesheetError):
"""Raised when the stylesheet could not be parsed."""

def __init__(self, errors: StylesheetErrors) -> None:
self.errors = errors

Expand All @@ -38,6 +40,8 @@ def __rich__(self) -> RenderableType:


class StylesheetErrors:
"""A renderable for stylesheet errors."""

def __init__(self, rules: list[RuleSet]) -> None:
self.rules = rules
self.variables: dict[str, str] = {}
Expand Down Expand Up @@ -134,6 +138,8 @@ class CssSource(NamedTuple):

@rich.repr.auto(angular=True)
class Stylesheet:
"""A Stylsheet generated from Textual CSS."""

def __init__(self, *, variables: dict[str, str] | None = None) -> None:
self._rules: list[RuleSet] = []
self._rules_map: dict[str, list[RuleSet]] | None = None
Expand Down Expand Up @@ -183,6 +189,10 @@ def rules_map(self) -> dict[str, list[RuleSet]]:

@property
def css(self) -> str:
"""The equivalent TCSS for this stylesheet.
Note that this may not produce the same content as the file(s) used to generate the stylesheet.
"""
return "\n\n".join(rule_set.css for rule_set in self.rules)

def copy(self) -> Stylesheet:
Expand Down Expand Up @@ -407,9 +417,18 @@ def reparse(self) -> None:

@classmethod
def _check_rule(
cls, rule: RuleSet, css_path_nodes: list[DOMNode]
cls, rule_set: RuleSet, css_path_nodes: list[DOMNode]
) -> Iterable[Specificity3]:
for selector_set in rule.selector_set:
"""Check a rule set, return specificity of applicable rules.
Args:
rule_set: A rule set.
css_path_nodes: A list of the nodes from the App to the node being checked.
Yields:
Specificity of any matching selectors.
"""
for selector_set in rule_set.selector_set:
if _check_selectors(selector_set.selectors, css_path_nodes):
yield selector_set.specificity

Expand Down
27 changes: 18 additions & 9 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
id: str | None = None,
classes: str | None = None,
) -> None:
self._classes = set()
self._classes: set[str] = set()
self._name = name
self._id = None
if id is not None:
Expand Down Expand Up @@ -575,8 +575,7 @@ def css_identifier_styled(self) -> Text:
@property
def pseudo_classes(self) -> frozenset[str]:
"""A (frozen) set of all pseudo classes."""
pseudo_classes = frozenset(self.get_pseudo_classes())
return pseudo_classes
return frozenset(self.get_pseudo_classes())

@property
def css_path_nodes(self) -> list[DOMNode]:
Expand Down Expand Up @@ -1276,17 +1275,27 @@ def toggle_class(self, *class_names: str) -> Self:
self._update_styles()
return self

def has_pseudo_class(self, *class_names: str) -> bool:
"""Check for pseudo classes (such as hover, focus etc)
def has_pseudo_class(self, class_name: str) -> bool:
"""Check the node has the given pseudo class.
Args:
*class_names: The pseudo classes to check for.
class_name: The pseudo class to check for.
Returns:
`True` if the DOM node has those pseudo classes, `False` if not.
`True` if the DOM node has the pseudo class, `False` if not.
"""
has_pseudo_classes = self.pseudo_classes.issuperset(class_names)
return has_pseudo_classes
return class_name in self.get_pseudo_classes()

def has_pseudo_classes(self, class_names: set[str]) -> bool:
"""Check the node has all the given pseudo classes.
Args:
class_names: Set of class names to check for.
Returns:
`True` if all pseudo class names are present.
"""
return class_names.issubset(self.get_pseudo_classes())

def refresh(self, *, repaint: bool = True, layout: bool = False) -> Self:
return self
16 changes: 8 additions & 8 deletions src/textual/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,14 +2876,6 @@ def get_pseudo_classes(self) -> Iterable[str]:
Returns:
Names of the pseudo classes.
"""
node: MessagePump | None = self
while isinstance(node, Widget):
if node.disabled:
yield "disabled"
break
node = node._parent
else:
yield "enabled"
if self.mouse_over:
yield "hover"
if self.has_focus:
Expand All @@ -2892,6 +2884,14 @@ def get_pseudo_classes(self) -> Iterable[str]:
yield "blur"
if self.can_focus:
yield "can-focus"
node: MessagePump | None = self
while isinstance(node, Widget):
if node.disabled:
yield "disabled"
break
node = node._parent
else:
yield "enabled"
try:
focused = self.screen.focused
except NoScreen:
Expand Down

0 comments on commit 8f822ae

Please sign in to comment.