diff --git a/CHANGELOG.md b/CHANGELOG.md index 38a6cb0c2b..c1666869d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Change default quit key to `ctrl+q` https://github.com/Textualize/textual/pull/5352 - Changed delete line binding on TextArea to use `ctrl+shift+x` https://github.com/Textualize/textual/pull/5352 - The command palette will now select the top item automatically https://github.com/Textualize/textual/pull/5361 +- Implemented a better matching algorithm for the command palette https://github.com/Textualize/textual/pull/5365 ### Fixed diff --git a/src/textual/containers.py b/src/textual/containers.py index 18a46d1455..a6f655a177 100644 --- a/src/textual/containers.py +++ b/src/textual/containers.py @@ -267,7 +267,7 @@ def __init__( stretch_height: bool = True, regular: bool = False, ) -> None: - """Initialize a Widget. + """ Args: *children: Child widgets. diff --git a/src/textual/fuzzy.py b/src/textual/fuzzy.py index 1f88feb742..337ad29b4d 100644 --- a/src/textual/fuzzy.py +++ b/src/textual/fuzzy.py @@ -7,13 +7,151 @@ from __future__ import annotations -from re import IGNORECASE, compile, escape +from operator import itemgetter +from re import IGNORECASE, escape, finditer, search +from typing import Iterable, NamedTuple import rich.repr from rich.style import Style from rich.text import Text -from textual.cache import LRUCache + +class _Search(NamedTuple): + """Internal structure to keep track of a recursive search.""" + + candidate_offset: int = 0 + query_offset: int = 0 + offsets: tuple[int, ...] = () + + def branch(self, offset: int) -> tuple[_Search, _Search]: + """Branch this search when an offset is found. + + Args: + offset: Offset of a matching letter in the query. + + Returns: + A pair of search objects. + """ + _, query_offset, offsets = self + return ( + _Search(offset + 1, query_offset + 1, offsets + (offset,)), + _Search(offset + 1, query_offset, offsets), + ) + + @property + def groups(self) -> int: + """Number of groups in offsets.""" + groups = 1 + last_offset = self.offsets[0] + for offset in self.offsets[1:]: + if offset != last_offset + 1: + groups += 1 + last_offset = offset + return groups + + +class FuzzySearch: + """Performs a fuzzy search. + + Unlike a regex solution, this will finds all possible matches. + """ + + def __init__(self, case_sensitive: bool = False) -> None: + """Initialize fuzzy search. + + Args: + case_sensitive: Is the match case sensitive? + """ + self.cache: dict[tuple[str, str, bool], tuple[float, tuple[int, ...]]] = {} + self.case_sensitive = case_sensitive + + def match(self, query: str, candidate: str) -> tuple[float, tuple[int, ...]]: + """Match against a query. + + Args: + query: The fuzzy query. + candidate: A candidate to check,. + + Returns: + A pair of (score, tuple of offsets). `(0, ())` for no result. + """ + + query_regex = ".*?".join(f"({escape(character)})" for character in query) + if not search( + query_regex, candidate, flags=0 if self.case_sensitive else IGNORECASE + ): + # Bail out early if there is no possibility of a match + return (0.0, ()) + + cache_key = (query, candidate, self.case_sensitive) + if cache_key in self.cache: + return self.cache[cache_key] + result = max( + self._match(query, candidate), key=itemgetter(0), default=(0.0, ()) + ) + self.cache[cache_key] = result + return result + + def _match( + self, query: str, candidate: str + ) -> Iterable[tuple[float, tuple[int, ...]]]: + """Generator to do the matching. + + Args: + query: Query to match. + candidate: Candidate to check against. + + Yields: + Pairs of score and tuple of offsets. + """ + if not self.case_sensitive: + query = query.lower() + candidate = candidate.lower() + + # We need this to give a bonus to first letters. + first_letters = {match.start() for match in finditer(r"\w+", candidate)} + + def score(search: _Search) -> float: + """Sore a search. + + Args: + search: Search object. + + Returns: + Score. + + """ + # This is a heuristic, and can be tweaked for better results + # Boost first letter matches + score: float = sum( + (2.0 if offset in first_letters else 1.0) for offset in search.offsets + ) + # Boost to favor less groups + offset_count = len(search.offsets) + normalized_groups = (offset_count - (search.groups - 1)) / offset_count + score *= 1 + (normalized_groups**2) + return score + + stack: list[_Search] = [_Search()] + push = stack.append + pop = stack.pop + query_size = len(query) + find = candidate.find + # Limit the number of loops out of an abundance of caution. + # This would be hard to reach without contrived data. + remaining_loops = 200 + + while stack and (remaining_loops := remaining_loops - 1): + search = pop() + offset = find(query[search.query_offset], search.candidate_offset) + if offset != -1: + advance_branch, branch = search.branch(offset) + if advance_branch.query_offset == query_size: + yield score(advance_branch), advance_branch.offsets + push(branch) + else: + push(advance_branch) + push(branch) @rich.repr.auto @@ -36,11 +174,8 @@ def __init__( """ self._query = query self._match_style = Style(reverse=True) if match_style is None else match_style - self._query_regex = compile( - ".*?".join(f"({escape(character)})" for character in query), - flags=0 if case_sensitive else IGNORECASE, - ) - self._cache: LRUCache[str, float] = LRUCache(1024 * 4) + self._case_sensitive = case_sensitive + self.fuzzy_search = FuzzySearch() @property def query(self) -> str: @@ -52,15 +187,10 @@ def match_style(self) -> Style: """The style that will be used to highlight hits in the matched text.""" return self._match_style - @property - def query_pattern(self) -> str: - """The regular expression pattern built from the query.""" - return self._query_regex.pattern - @property def case_sensitive(self) -> bool: """Is this matcher case sensitive?""" - return not bool(self._query_regex.flags & IGNORECASE) + return self._case_sensitive def match(self, candidate: str) -> float: """Match the candidate against the query. @@ -71,27 +201,7 @@ def match(self, candidate: str) -> float: Returns: Strength of the match from 0 to 1. """ - cached = self._cache.get(candidate) - if cached is not None: - return cached - match = self._query_regex.search(candidate) - if match is None: - score = 0.0 - else: - assert match.lastindex is not None - offsets = [ - match.span(group_no)[0] for group_no in range(1, match.lastindex + 1) - ] - group_count = 0 - last_offset = -2 - for offset in offsets: - if offset > last_offset + 1: - group_count += 1 - last_offset = offset - - score = 1.0 - ((group_count - 1) / len(candidate)) - self._cache[candidate] = score - return score + return self.fuzzy_search.match(self.query, candidate)[0] def highlight(self, candidate: str) -> Text: """Highlight the candidate with the fuzzy match. @@ -102,20 +212,11 @@ def highlight(self, candidate: str) -> Text: Returns: A [rich.text.Text][`Text`] object with highlighted matches. """ - match = self._query_regex.search(candidate) text = Text.from_markup(candidate) - if match is None: + score, offsets = self.fuzzy_search.match(self.query, candidate) + if not score: return text - assert match.lastindex is not None - if self._query in text.plain: - # Favor complete matches - offset = text.plain.index(self._query) - text.stylize(self._match_style, offset, offset + len(self._query)) - else: - offsets = [ - match.span(group_no)[0] for group_no in range(1, match.lastindex + 1) - ] - for offset in offsets: + for offset in offsets: + if not candidate[offset].isspace(): text.stylize(self._match_style, offset, offset + 1) - return text diff --git a/tests/snapshot_tests/test_snapshots.py b/tests/snapshot_tests/test_snapshots.py index a6723b8be2..c3c9791068 100644 --- a/tests/snapshot_tests/test_snapshots.py +++ b/tests/snapshot_tests/test_snapshots.py @@ -1510,7 +1510,7 @@ def test_example_color_command(snap_compare): """Test the color_command example.""" assert snap_compare( EXAMPLES_DIR / "color_command.py", - press=[App.COMMAND_PALETTE_BINDING, "r", "e", "d", "down", "enter"], + press=[App.COMMAND_PALETTE_BINDING, "r", "e", "d", "enter"], ) diff --git a/tests/test_fuzzy.py b/tests/test_fuzzy.py index d2ab460c9a..dc3c8ccd92 100644 --- a/tests/test_fuzzy.py +++ b/tests/test_fuzzy.py @@ -4,25 +4,24 @@ from textual.fuzzy import Matcher -def test_match(): - matcher = Matcher("foo.bar") +def test_no_match(): + """Check non matching score of zero.""" + matcher = Matcher("x") + assert matcher.match("foo") == 0 + - # No match - assert matcher.match("egg") == 0 - assert matcher.match("") == 0 +def test_match_single_group(): + """Check that single groups rang higher.""" + matcher = Matcher("abc") + assert matcher.match("foo abc bar") > matcher.match("fooa barc") - # Perfect match - assert matcher.match("foo.bar") == 1.0 - # Perfect match (with superfluous characters) - assert matcher.match("foo.bar sdf") == 1.0 - assert matcher.match("xz foo.bar sdf") == 1.0 - # Partial matches - # 2 Groups - assert matcher.match("foo egg.bar") == 1.0 - 1 / 11 +def test_boosted_matches(): + """Check first word matchers rank higher.""" + matcher = Matcher("ss") - # 3 Groups - assert matcher.match("foo .ba egg r") == 1.0 - 2 / 13 + # First word matchers should score higher + assert matcher.match("Save Screenshot") > matcher.match("Show Keys abcde") def test_highlight():