diff --git a/CHANGELOG.md b/CHANGELOG.md index fc8b457d85..54113f04e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Changed - The DataTable cursor is now scrolled into view when the cursor coordinate is changed programmatically https://github.com/Textualize/textual/issues/2459 +- Added `key_function` as an optional argument for `DataTable.sort` https://github.com/Textualize/textual/pull/2512 - run_worker exclusive parameter is now `False` by default https://github.com/Textualize/textual/pull/2470 - Added `always_update` as an optional argument for `reactive.var` - Made Binding description default to empty string, which is equivalent to show=False https://github.com/Textualize/textual/pull/2501 diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 0f40552be0..b1aa007ad6 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast import rich.repr from rich.console import RenderableType @@ -13,7 +13,7 @@ from rich.segment import Segment from rich.style import Style from rich.text import Text, TextType -from typing_extensions import Literal, Self, TypeAlias +from typing_extensions import Literal from .. import events from .._cache import LRUCache @@ -32,6 +32,10 @@ from ..strip import Strip from ..widget import PseudoClasses +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + from typing_extensions import Self, TypeAlias + CellCacheKey: TypeAlias = ( "tuple[RowKey, ColumnKey, Style, bool, bool, int, PseudoClasses]" ) @@ -1981,26 +1985,33 @@ def sort( self, *columns: ColumnKey | str, reverse: bool = False, + key_function: Callable[[tuple[RowKey, dict[ColumnKey, CellType]]], + SupportsRichComparison] | None = None, ) -> Self: """Sort the rows in the `DataTable` by one or more column keys. Args: - columns: One or more columns to sort by the values in. + columns: One or more columns to sort by (unless key_function is set). reverse: If True, the sort order will be reversed. + key_function: A custom function to extract a comparison key for each row. Returns: The `DataTable` instance. + """ def sort_by_column_keys( - row: tuple[RowKey, dict[ColumnKey | str, CellType]] - ) -> Any: + row: tuple[RowKey, dict[ColumnKey, CellType]] + ) -> SupportsRichComparison: _, row_data = row result = itemgetter(*columns)(row_data) return result + if key_function is None: + key_function = sort_by_column_keys + ordered_rows = sorted( - self._data.items(), key=sort_by_column_keys, reverse=reverse + self._data.items(), key=key_function, reverse=reverse ) self._row_locations = TwoWayDict( {key: new_index for new_index, (key, _) in enumerate(ordered_rows)}