diff --git a/CHANGELOG.md b/CHANGELOG.md index d583eaee32..5baf8b8e7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410 - Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525 +- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090 +- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566 +- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571 +- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498 + ### Changed @@ -46,15 +51,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Breaking change: empty rules now result in an error https://github.com/Textualize/textual/pull/3566 - Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575 -### Added - -- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566 -- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571 - -### Added - -- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498 - ## [0.40.0] - 2023-10-11 ### Added @@ -248,7 +244,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - DescendantBlur and DescendantFocus can now be used with @on decorator - ## [0.32.0] - 2023-08-03 ### Added diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py new file mode 100644 index 0000000000..599a629394 --- /dev/null +++ b/docs/examples/widgets/data_table_sort.py @@ -0,0 +1,92 @@ +from rich.text import Text + +from textual.app import App, ComposeResult +from textual.widgets import DataTable, Footer + +ROWS = [ + ("lane", "swimmer", "country", "time 1", "time 2"), + (4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84), + (2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84), + (5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73), + (6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58), + (3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26), + (8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15), + (7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12), + (1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85), + (10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55), +] + + +class TableApp(App): + BINDINGS = [ + ("a", "sort_by_average_time", "Sort By Average Time"), + ("n", "sort_by_last_name", "Sort By Last Name"), + ("c", "sort_by_country", "Sort By Country"), + ("d", "sort_by_columns", "Sort By Columns (Only)"), + ] + + current_sorts: set = set() + + def compose(self) -> ComposeResult: + yield DataTable() + yield Footer() + + def on_mount(self) -> None: + table = self.query_one(DataTable) + for col in ROWS[0]: + table.add_column(col, key=col) + table.add_rows(ROWS[1:]) + + def sort_reverse(self, sort_type: str): + """Determine if `sort_type` is ascending or descending.""" + reverse = sort_type in self.current_sorts + if reverse: + self.current_sorts.remove(sort_type) + else: + self.current_sorts.add(sort_type) + return reverse + + def action_sort_by_average_time(self) -> None: + """Sort DataTable by average of times (via a function) and + passing of column data through positional arguments.""" + + def sort_by_average_time_then_last_name(row_data): + name, *scores = row_data + return (sum(scores) / len(scores), name.split()[-1]) + + table = self.query_one(DataTable) + table.sort( + "swimmer", + "time 1", + "time 2", + key=sort_by_average_time_then_last_name, + reverse=self.sort_reverse("time"), + ) + + def action_sort_by_last_name(self) -> None: + """Sort DataTable by last name of swimmer (via a lambda).""" + table = self.query_one(DataTable) + table.sort( + "swimmer", + key=lambda swimmer: swimmer.split()[-1], + reverse=self.sort_reverse("swimmer"), + ) + + def action_sort_by_country(self) -> None: + """Sort DataTable by country which is a `Rich.Text` object.""" + table = self.query_one(DataTable) + table.sort( + "country", + key=lambda country: country.plain, + reverse=self.sort_reverse("country"), + ) + + def action_sort_by_columns(self) -> None: + """Sort DataTable without a key.""" + table = self.query_one(DataTable) + table.sort("swimmer", "lane", reverse=self.sort_reverse("columns")) + + +app = TableApp() +if __name__ == "__main__": + app.run() diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md index a676c4dca1..ab1981c0f1 100644 --- a/docs/widgets/data_table.md +++ b/docs/widgets/data_table.md @@ -143,11 +143,22 @@ visible as you scroll through the data table. ### Sorting -The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. -In order to sort your data by a column, you must have supplied a `key` to the `add_column` method -when you added it. -You can then pass this key to the `sort` method to sort by that column. -Additionally, you can sort by multiple columns by passing multiple keys to `sort`. +The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns. + +Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes. + +Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified. + +=== "Output" + + ```{.textual path="docs/examples/widgets/data_table_sort.py"} + ``` + +=== "data_table_sort.py" + + ```python + --8<-- "docs/examples/widgets/data_table_sort.py" + ``` ### Labelled rows diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 4b482424f9..54c35a63a3 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast +from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast import rich.repr from rich.console import RenderableType @@ -2348,30 +2348,40 @@ def _get_fixed_offset(self) -> Spacing: def sort( self, *columns: ColumnKey | str, + key: Callable[[Any], Any] | None = None, reverse: bool = False, ) -> Self: - """Sort the rows in the `DataTable` by one or more column keys. + """Sort the rows in the `DataTable` by one or more column keys or a + key function (or other callable). If both columns and a key function + are specified, only data from those columns will sent to the key function. Args: columns: One or more columns to sort by the values in. + key: A function (or other callable) that returns a key to + use for sorting purposes. reverse: If True, the sort order will be reversed. Returns: The `DataTable` instance. """ - def sort_by_column_keys( - row: tuple[RowKey, dict[ColumnKey | str, CellType]] - ) -> Any: + def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: _, row_data = row - result = itemgetter(*columns)(row_data) + if columns: + result = itemgetter(*columns)(row_data) + else: + result = tuple(row_data.values()) + if key is not None: + return key(result) return result ordered_rows = sorted( - self._data.items(), key=sort_by_column_keys, reverse=reverse + self._data.items(), + key=key_wrapper, + reverse=reverse, ) self._row_locations = TwoWayDict( - {key: new_index for new_index, (key, _) in enumerate(ordered_rows)} + {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)} ) self._update_count += 1 self.refresh() diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 15ded2563e..4f473f978f 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse(): assert not table._show_hover_cursor +async def test_sort_by_all_columns_no_key(): + """Test sorting a `DataTable` by all columns.""" + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + table.add_row(1, 3, 8) + table.add_row(2, 9, 5) + table.add_row(1, 1, 9) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [2, 9, 5] + assert table.get_row_at(2) == [1, 1, 9] + + table.sort() + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort(reverse=True) + assert table.get_row_at(0) == [2, 9, 5] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [1, 1, 9] + + +async def test_sort_by_multiple_columns_no_key(): + """Test sorting a `DataTable` by multiple columns.""" + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + table.add_row(1, 3, 8) + table.add_row(2, 9, 5) + table.add_row(1, 1, 9) + + table.sort(a, b, c) + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort(a, c, b) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [1, 1, 9] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort(c, a, b, reverse=True) + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort(a, c) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [1, 1, 9] + assert table.get_row_at(2) == [2, 9, 5] + + +async def test_sort_by_function_sum(): + """Test sorting a `DataTable` using a custom sort function.""" + + def custom_sort(row_data): + return sum(row_data) + + row_data = ( + [1, 3, 8], # SUM=12 + [2, 9, 5], # SUM=16 + [1, 1, 9], # SUM=11 + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + for i, row in enumerate(row_data): + table.add_row(*row) + + # Sorting by all columns + table.sort(a, b, c, key=custom_sort) + sorted_row_data = sorted(row_data, key=sum) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + # Passing a sort function but no columns also sorts by all columns + table.sort(key=custom_sort) + sorted_row_data = sorted(row_data, key=sum) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort(a, b, c, key=custom_sort, reverse=True) + sorted_row_data = sorted(row_data, key=sum, reverse=True) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + @pytest.mark.parametrize( ["cell", "height"], [