From fcce7e2e7baf180c673796274ab6fcc909ef9312 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Fri, 11 Aug 2023 00:10:46 -0400 Subject: [PATCH 01/10] DataTable sort by function (or other callable) The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). Covers #2261 using [suggested function signature](https://github.com/Textualize/textual/pull/2512#issuecomment-1580277771) from @darrenburns on PR #2512. --- CHANGELOG.md | 5 +- docs/examples/widgets/data_table_sort.py | 78 ++++++++ docs/widgets/data_table.md | 19 +- src/textual/widgets/_data_table.py | 17 +- .../snapshot_apps/data_table_sort.py | 2 +- tests/test_data_table.py | 180 +++++++++++++++--- 6 files changed, 263 insertions(+), 38 deletions(-) create mode 100644 docs/examples/widgets/data_table_sort.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f819b53a..7b486cfad2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - `MouseMove` events bubble up from widgets. `App` and `Screen` receive `MouseMove` events even if there's no Widget under the cursor. https://github.com/Textualize/textual/issues/2905 ### Added -- Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603 +- Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603 - Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757 +### Changed +- The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). https://github.com/Textualize/textual/issues/2261 + ## [0.32.0] - 2023-08-03 ### Added diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py new file mode 100644 index 0000000000..d6afea74fa --- /dev/null +++ b/docs/examples/widgets/data_table_sort.py @@ -0,0 +1,78 @@ +from rich.text import Text + +from textual.app import App, ComposeResult +from textual.widgets import DataTable, Footer + +ROWS = [ + ("lane", "swimmer", "country", "time 1", "time 2"), + (4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84), + (2, "Michael Phelps", Text("United States", style="italic"), 51.14, 51.84), + (5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73), + (6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58), + (3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26), + (8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15), + (7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12), + (1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85), + (10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55), +] + + +class TableApp(App): + BINDINGS = [ + ("a", "sort_by_average_time", "Sort By Average Time"), + ("n", "sort_by_last_name", "Sort By Last Name"), + ("c", "sort_by_country", "Sort By Country"), + ] + + current_sorts: set = set() + + def compose(self) -> ComposeResult: + yield DataTable() + yield Footer() + + def on_mount(self) -> None: + table = self.query_one(DataTable) + for col in ROWS[0]: + table.add_column(col, key=col) + table.add_rows(ROWS[1:]) + + def sort_reverse(self, sort_type: str): + """Determine if `sort_type` is ascending or descending.""" + reverse = sort_type in self.current_sorts + if reverse: + self.current_sorts.remove(sort_type) + else: + self.current_sorts.add(sort_type) + return reverse + + def action_sort_by_average_time(self) -> None: + """Sort DataTable by average of times (via a function).""" + + def sort_by_average_time(row): + _, row_data = row + times = [row_data["time 1"], row_data["time 2"]] + return sum(times) / len(times) + + table = self.query_one(DataTable) + table.sort(sort_by_average_time, reverse=self.sort_reverse("time")) + + def action_sort_by_last_name(self) -> None: + """Sort DataTable by last name of swimmer (via a lambda).""" + table = self.query_one(DataTable) + table.sort( + lambda row: row[1]["swimmer"].split()[-1], + reverse=self.sort_reverse("swimmer"), + ) + + def action_sort_by_country(self) -> None: + """Sort DataTable by country which is a `Rich.Text` object.""" + table = self.query_one(DataTable) + table.sort( + lambda row: row[1]["country"].plain, + reverse=self.sort_reverse("country"), + ) + + +app = TableApp() +if __name__ == "__main__": + app.run() diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md index 0ae59829b6..46af1bfca1 100644 --- a/docs/widgets/data_table.md +++ b/docs/widgets/data_table.md @@ -143,11 +143,20 @@ visible as you scroll through the data table. ### Sorting -The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. -In order to sort your data by a column, you must have supplied a `key` to the `add_column` method -when you added it. -You can then pass this key to the `sort` method to sort by that column. -Additionally, you can sort by multiple columns by passing multiple keys to `sort`. +The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass this key to the `sort` method inside of an iterable. To sort by multiple columns, pass an iterable with multiple keys to `sort`. + +Additionally, you can sort your `DataTable` with a custom function (or other callable). Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and returns a key to use for sorting purposes. + +=== "Output" + + ```{.textual path="docs/examples/widgets/data_table_sort.py"} + ``` + +=== "data_table_sort.py" + + ```python + --8<-- "docs/examples/widgets/data_table_sort.py" + ``` ### Labelled rows diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index db3b2801d4..c6ea2b4958 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast +from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast import rich.repr from rich.console import RenderableType @@ -2107,13 +2107,15 @@ def _get_fixed_offset(self) -> Spacing: def sort( self, - *columns: ColumnKey | str, + by: Iterable[ColumnKey | str] | Callable, reverse: bool = False, ) -> Self: - """Sort the rows in the `DataTable` by one or more column keys. + """Sort the rows in the `DataTable` by one or more column keys or a + key function (or other callable). Args: - columns: One or more columns to sort by the values in. + by: One or more columns to sort by the values by, or a function + (or other callable) that returns a key to use for sorting purposes. reverse: If True, the sort order will be reversed. Returns: @@ -2124,12 +2126,11 @@ def sort_by_column_keys( row: tuple[RowKey, dict[ColumnKey | str, CellType]] ) -> Any: _, row_data = row - result = itemgetter(*columns)(row_data) + result = itemgetter(*by)(row_data) return result - ordered_rows = sorted( - self._data.items(), key=sort_by_column_keys, reverse=reverse - ) + key = by if isinstance(by, Callable) else sort_by_column_keys + ordered_rows = sorted(self._data.items(), key=key, reverse=reverse) self._row_locations = TwoWayDict( {key: new_index for new_index, (key, _) in enumerate(ordered_rows)} ) diff --git a/tests/snapshot_tests/snapshot_apps/data_table_sort.py b/tests/snapshot_tests/snapshot_apps/data_table_sort.py index b4866e0ca5..dacf8f3a8d 100644 --- a/tests/snapshot_tests/snapshot_apps/data_table_sort.py +++ b/tests/snapshot_tests/snapshot_apps/data_table_sort.py @@ -36,7 +36,7 @@ def on_mount(self) -> None: def action_sort(self): table = self.query_one(DataTable) - table.sort("time", "lane") + table.sort(["time", "lane"]) app = TableApp() diff --git a/tests/test_data_table.py b/tests/test_data_table.py index a1c08edc7c..a92a33b467 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,5 +1,7 @@ from __future__ import annotations +from operator import itemgetter + import pytest from rich.text import Text @@ -417,11 +419,11 @@ async def test_get_cell_coordinate_returns_coordinate(): table.add_row("ValR2C1", "ValR2C2", "ValR2C3", key="R2") table.add_row("ValR3C1", "ValR3C2", "ValR3C3", key="R3") - assert table.get_cell_coordinate('R1', 'C1') == Coordinate(0, 0) - assert table.get_cell_coordinate('R2', 'C2') == Coordinate(1, 1) - assert table.get_cell_coordinate('R1', 'C3') == Coordinate(0, 2) - assert table.get_cell_coordinate('R3', 'C1') == Coordinate(2, 0) - assert table.get_cell_coordinate('R3', 'C2') == Coordinate(2, 1) + assert table.get_cell_coordinate("R1", "C1") == Coordinate(0, 0) + assert table.get_cell_coordinate("R2", "C2") == Coordinate(1, 1) + assert table.get_cell_coordinate("R1", "C3") == Coordinate(0, 2) + assert table.get_cell_coordinate("R3", "C1") == Coordinate(2, 0) + assert table.get_cell_coordinate("R3", "C2") == Coordinate(2, 1) async def test_get_cell_coordinate_invalid_row_key(): @@ -432,7 +434,7 @@ async def test_get_cell_coordinate_invalid_row_key(): table.add_row("TargetValue", key="R1") with pytest.raises(CellDoesNotExist): - coordinate = table.get_cell_coordinate('INVALID_ROW', 'C1') + coordinate = table.get_cell_coordinate("INVALID_ROW", "C1") async def test_get_cell_coordinate_invalid_column_key(): @@ -443,7 +445,7 @@ async def test_get_cell_coordinate_invalid_column_key(): table.add_row("TargetValue", key="R1") with pytest.raises(CellDoesNotExist): - coordinate = table.get_cell_coordinate('R1', 'INVALID_COLUMN') + coordinate = table.get_cell_coordinate("R1", "INVALID_COLUMN") async def test_get_cell_at_returns_value_at_cell(): @@ -476,7 +478,7 @@ async def test_get_row(): assert table.get_row(second_row) == [3, 2, 1] # Even if row positions change, keys should always refer to same rows. - table.sort(b) + table.sort([b]) assert table.get_row(first_row) == [2, 4, 1] assert table.get_row(second_row) == [3, 2, 1] @@ -500,7 +502,7 @@ async def test_get_row_at(): assert table.get_row_at(1) == [3, 2, 1] # If we sort, then the rows present at the indices *do* change! - table.sort(b) + table.sort([b]) # Since we sorted on column "B", the rows at indices 0 and 1 are swapped. assert table.get_row_at(0) == [3, 2, 1] @@ -529,9 +531,9 @@ async def test_get_row_index_returns_index(): table.add_row("ValR2C1", "ValR2C2", key="R2") table.add_row("ValR3C1", "ValR3C2", key="R3") - assert table.get_row_index('R1') == 0 - assert table.get_row_index('R2') == 1 - assert table.get_row_index('R3') == 2 + assert table.get_row_index("R1") == 0 + assert table.get_row_index("R2") == 1 + assert table.get_row_index("R3") == 2 async def test_get_row_index_invalid_row_key(): @@ -542,7 +544,7 @@ async def test_get_row_index_invalid_row_key(): table.add_row("TargetValue", key="R1") with pytest.raises(RowDoesNotExist): - index = table.get_row_index('InvalidRow') + index = table.get_row_index("InvalidRow") async def test_get_column(): @@ -589,6 +591,7 @@ async def test_get_column_at_invalid_index(index): with pytest.raises(ColumnDoesNotExist): list(table.get_column_at(index)) + async def test_get_column_index_returns_index(): app = DataTableApp() async with app.run_test(): @@ -596,12 +599,12 @@ async def test_get_column_index_returns_index(): table.add_column("Column1", key="C1") table.add_column("Column2", key="C2") table.add_column("Column3", key="C3") - table.add_row("ValR1C1", "ValR1C2", "ValR1C3", key="R1") - table.add_row("ValR2C1", "ValR2C2", "ValR2C3", key="R2") + table.add_row("ValR1C1", "ValR1C2", "ValR1C3", key="R1") + table.add_row("ValR2C1", "ValR2C2", "ValR2C3", key="R2") - assert table.get_column_index('C1') == 0 - assert table.get_column_index('C2') == 1 - assert table.get_column_index('C3') == 2 + assert table.get_column_index("C1") == 0 + assert table.get_column_index("C2") == 1 + assert table.get_column_index("C3") == 2 async def test_get_column_index_invalid_column_key(): @@ -611,11 +614,10 @@ async def test_get_column_index_invalid_column_key(): table.add_column("Column1", key="C1") table.add_column("Column2", key="C2") table.add_column("Column3", key="C3") - table.add_row("TargetValue1", "TargetValue2", "TargetValue3", key="R1") + table.add_row("TargetValue1", "TargetValue2", "TargetValue3", key="R1") with pytest.raises(ColumnDoesNotExist): - index = table.get_column_index('InvalidCol') - + index = table.get_column_index("InvalidCol") async def test_update_cell_cell_exists(): @@ -876,7 +878,7 @@ async def test_sort_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort(column) + table.sort([column]) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -909,7 +911,7 @@ async def test_sort_reverse_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort(column, reverse=True) + table.sort([column], reverse=True) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -1159,3 +1161,135 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse(): # the widget, and the hover cursor is hidden await pilot.hover(DataTable, offset=Offset(42, 1)) assert not table._show_hover_cursor + + +async def test_sort_by_multiple_columns(): + """Test sorting a `DataTable` my multiple columns.""" + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + table.add_row(1, 3, 8) + table.add_row(2, 9, 5) + table.add_row(1, 1, 9) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [2, 9, 5] + assert table.get_row_at(2) == [1, 1, 9] + + table.sort([a, b, c]) + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort([a, c, b]) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [1, 1, 9] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort([c, a, b], reverse=True) + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort([a, c]) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [1, 1, 9] + assert table.get_row_at(2) == [2, 9, 5] + + +async def test_sort_by_function_sum(): + """Test sorting a `DataTable` using a custom sort function.""" + + def custom_sort(row): + _, row_data = row + return sum(row_data.values()) + + row_data = ( + [1, 3, 8], # SUM=12 + [2, 9, 5], # SUM=16 + [1, 1, 9], # SUM=11 + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + + for i, row in enumerate(row_data): + table.add_row(*row) + assert table.get_row_at(i) == row + assert sum(table.get_row_at(i)) == sum(row) + + table.sort(custom_sort) + sorted_row_data = sorted(row_data, key=sum) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort(custom_sort, reverse=True) + sorted_row_data = sorted(row_data, key=sum, reverse=True) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + +async def test_sort_by_function_sum_lambda(): + """Test sorting a `DataTable` using a custom sort lambda.""" + row_data = ( + [1, 3, 8], # SUM=12 + [2, 9, 5], # SUM=16 + [1, 1, 9], # SUM=11 + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + + for i, row in enumerate(row_data): + table.add_row(*row) + assert table.get_row_at(i) == row + assert sum(table.get_row_at(i)) == sum(row) + + table.sort(lambda row: sum(row[1].values())) + sorted_row_data = sorted(row_data, key=sum) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort(lambda row: sum(row[1].values()), reverse=True) + sorted_row_data = sorted(row_data, key=sum, reverse=True) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + +async def test_sort_by_function_date(): + """Test sorting a `DataTable` by date when an abnormal date format is used. + Based on a question in the Discourse 'Help Wanted' section. + + Example: 26/07/23 == 2023-07-26 + """ + + def custom_sort(row): + _, row_data = row + result = itemgetter(birthdate)(row_data) + d, m, y = result.split("/") + return f"{y}{m}{d}" + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + _, birthdate = table.add_columns("Name", "Birthdate") + table.add_row("Doug Johnson", "26/07/23") + table.add_row("Albert Douglass", "16/07/23") + table.add_row("Jane Doe", "25/12/22") + assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] + assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] + assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] + + table.sort(custom_sort) + assert table.get_row_at(0) == ["Jane Doe", "25/12/22"] + assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] + assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"] + + table.sort(custom_sort, reverse=True) + assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] + assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] + assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] From 69730d8ef648cd9067589c33c63b75b22a2b369d Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Mon, 28 Aug 2023 14:45:08 -0400 Subject: [PATCH 02/10] argument change and functionaloty update Changed back to orinal `columns` argument and added a new `key` argument which takes a function (or other callable). This allows the PR to NOT BE a breaking change. --- CHANGELOG.md | 4 +- docs/examples/widgets/data_table_sort.py | 34 +++++++--- docs/widgets/data_table.md | 6 +- src/textual/widgets/_data_table.py | 28 ++++++--- .../snapshot_apps/data_table_sort.py | 2 +- tests/test_data_table.py | 63 ++++++++++++++----- 6 files changed, 101 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b486cfad2..fb8806cad2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,9 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603 - Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757 - -### Changed -- The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). https://github.com/Textualize/textual/issues/2261 +- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261 ## [0.32.0] - 2023-08-03 diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py index d6afea74fa..1844b78d9f 100644 --- a/docs/examples/widgets/data_table_sort.py +++ b/docs/examples/widgets/data_table_sort.py @@ -1,6 +1,7 @@ from rich.text import Text from textual.app import App, ComposeResult +from textual.events import Click from textual.widgets import DataTable, Footer ROWS = [ @@ -46,21 +47,25 @@ def sort_reverse(self, sort_type: str): return reverse def action_sort_by_average_time(self) -> None: - """Sort DataTable by average of times (via a function).""" + """Sort DataTable by average of times (via a function) and + passing of column data through positional arguments.""" - def sort_by_average_time(row): - _, row_data = row - times = [row_data["time 1"], row_data["time 2"]] - return sum(times) / len(times) + def sort_by_average_time(times): + return sum(n for n in times) / len(times) table = self.query_one(DataTable) - table.sort(sort_by_average_time, reverse=self.sort_reverse("time")) + table.sort( + "time 1", + "time 2", + key=sort_by_average_time, + reverse=self.sort_reverse("time"), + ) def action_sort_by_last_name(self) -> None: """Sort DataTable by last name of swimmer (via a lambda).""" table = self.query_one(DataTable) table.sort( - lambda row: row[1]["swimmer"].split()[-1], + key=lambda row: row[1]["swimmer"].split()[-1], reverse=self.sort_reverse("swimmer"), ) @@ -68,10 +73,23 @@ def action_sort_by_country(self) -> None: """Sort DataTable by country which is a `Rich.Text` object.""" table = self.query_one(DataTable) table.sort( - lambda row: row[1]["country"].plain, + "country", + key=lambda country: country.plain, reverse=self.sort_reverse("country"), ) + def on_data_table_header_selected(self, event: Click) -> None: + """Sort `DataTable` items by the clicked column header.""" + + def sort_by_plain_text(cell): + return cell.plain if isinstance(cell, Text) else cell + + column_key = event.column_key + table = self.query_one(DataTable) + table.sort( + column_key, key=sort_by_plain_text, reverse=self.sort_reverse(column_key) + ) + app = TableApp() if __name__ == "__main__": diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md index 46af1bfca1..bd4919c733 100644 --- a/docs/widgets/data_table.md +++ b/docs/widgets/data_table.md @@ -143,9 +143,11 @@ visible as you scroll through the data table. ### Sorting -The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass this key to the `sort` method inside of an iterable. To sort by multiple columns, pass an iterable with multiple keys to `sort`. +The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns. -Additionally, you can sort your `DataTable` with a custom function (or other callable). Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and returns a key to use for sorting purposes. +Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes. + +Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified. === "Output" diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index c6ea2b4958..1faed042f3 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2107,15 +2107,18 @@ def _get_fixed_offset(self) -> Spacing: def sort( self, - by: Iterable[ColumnKey | str] | Callable, + *columns: ColumnKey | str, + key: Callable | None = None, reverse: bool = False, ) -> Self: """Sort the rows in the `DataTable` by one or more column keys or a - key function (or other callable). + key function (or other callable). If both columns and a key function + are specified, only data from those columns will sent to the key function. Args: - by: One or more columns to sort by the values by, or a function - (or other callable) that returns a key to use for sorting purposes. + columns: One or more columns to sort by the values in. + key: A function (or other callable) that returns a key to + use for sorting purposes. reverse: If True, the sort order will be reversed. Returns: @@ -2126,13 +2129,22 @@ def sort_by_column_keys( row: tuple[RowKey, dict[ColumnKey | str, CellType]] ) -> Any: _, row_data = row - result = itemgetter(*by)(row_data) + result = itemgetter(*columns)(row_data) return result - key = by if isinstance(by, Callable) else sort_by_column_keys - ordered_rows = sorted(self._data.items(), key=key, reverse=reverse) + _key = key + if key and columns: + + def _key(row): + return key(itemgetter(*columns)(row[1])) + + ordered_rows = sorted( + self._data.items(), + key=_key if _key else sort_by_column_keys, + reverse=reverse, + ) self._row_locations = TwoWayDict( - {key: new_index for new_index, (key, _) in enumerate(ordered_rows)} + {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)} ) self._update_count += 1 self.refresh() diff --git a/tests/snapshot_tests/snapshot_apps/data_table_sort.py b/tests/snapshot_tests/snapshot_apps/data_table_sort.py index dacf8f3a8d..b4866e0ca5 100644 --- a/tests/snapshot_tests/snapshot_apps/data_table_sort.py +++ b/tests/snapshot_tests/snapshot_apps/data_table_sort.py @@ -36,7 +36,7 @@ def on_mount(self) -> None: def action_sort(self): table = self.query_one(DataTable) - table.sort(["time", "lane"]) + table.sort("time", "lane") app = TableApp() diff --git a/tests/test_data_table.py b/tests/test_data_table.py index a92a33b467..fb46343e8d 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -478,7 +478,7 @@ async def test_get_row(): assert table.get_row(second_row) == [3, 2, 1] # Even if row positions change, keys should always refer to same rows. - table.sort([b]) + table.sort(b) assert table.get_row(first_row) == [2, 4, 1] assert table.get_row(second_row) == [3, 2, 1] @@ -502,7 +502,7 @@ async def test_get_row_at(): assert table.get_row_at(1) == [3, 2, 1] # If we sort, then the rows present at the indices *do* change! - table.sort([b]) + table.sort(b) # Since we sorted on column "B", the rows at indices 0 and 1 are swapped. assert table.get_row_at(0) == [3, 2, 1] @@ -878,7 +878,7 @@ async def test_sort_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort([column]) + table.sort(column) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -911,7 +911,7 @@ async def test_sort_reverse_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort([column], reverse=True) + table.sort(column, reverse=True) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -1176,22 +1176,22 @@ async def test_sort_by_multiple_columns(): assert table.get_row_at(1) == [2, 9, 5] assert table.get_row_at(2) == [1, 1, 9] - table.sort([a, b, c]) + table.sort(a, b, c) assert table.get_row_at(0) == [1, 1, 9] assert table.get_row_at(1) == [1, 3, 8] assert table.get_row_at(2) == [2, 9, 5] - table.sort([a, c, b]) + table.sort(a, c, b) assert table.get_row_at(0) == [1, 3, 8] assert table.get_row_at(1) == [1, 1, 9] assert table.get_row_at(2) == [2, 9, 5] - table.sort([c, a, b], reverse=True) + table.sort(c, a, b, reverse=True) assert table.get_row_at(0) == [1, 1, 9] assert table.get_row_at(1) == [1, 3, 8] assert table.get_row_at(2) == [2, 9, 5] - table.sort([a, c]) + table.sort(a, c) assert table.get_row_at(0) == [1, 3, 8] assert table.get_row_at(1) == [1, 1, 9] assert table.get_row_at(2) == [2, 9, 5] @@ -1220,12 +1220,12 @@ def custom_sort(row): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(custom_sort) + table.sort(key=custom_sort) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(custom_sort, reverse=True) + table.sort(key=custom_sort, reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row @@ -1249,12 +1249,12 @@ async def test_sort_by_function_sum_lambda(): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(lambda row: sum(row[1].values())) + table.sort(key=lambda row: sum(row[1].values())) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(lambda row: sum(row[1].values()), reverse=True) + table.sort(key=lambda row: sum(row[1].values()), reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row @@ -1284,12 +1284,47 @@ def custom_sort(row): assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] - table.sort(custom_sort) + table.sort(key=custom_sort) assert table.get_row_at(0) == ["Jane Doe", "25/12/22"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"] - table.sort(custom_sort, reverse=True) + table.sort(key=custom_sort, reverse=True) assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] + + +async def test_sort_by_columns_and_function(): + """Test sorting a `DataTable` using a custom sort function and + only supplying data from specific columns.""" + + def custom_sort(nums): + return sum(n for n in nums) + + row_data = ( + ["A", "B", "C"], # Column headers + [1, 3, 8], # First and last columns SUM=9 + [2, 9, 5], # First and last columns SUM=7 + [1, 1, 9], # First and last columns SUM=10 + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + + for col in row_data[0]: + table.add_column(col, key=col) + table.add_rows(row_data[1:]) + + table.sort("A", "C", key=custom_sort) + sorted_row_data = sorted(row_data[1:], key=lambda row: row[0] + row[-1]) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort("A", "C", key=custom_sort, reverse=True) + sorted_row_data = sorted( + row_data[1:], key=lambda row: row[0] + row[-1], reverse=True + ) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row From 7cd0494be5fe5e03f47927b06308b7314728d857 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Mon, 28 Aug 2023 21:28:56 -0400 Subject: [PATCH 03/10] better example for docs - Updated the example file for the docs to better show the functionality of the change (especially when using `columns` and `key` together). - Added one new tests to cover a similar situation to the example changes --- docs/examples/widgets/data_table_sort.py | 13 +++++---- tests/test_data_table.py | 37 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py index 1844b78d9f..73669601be 100644 --- a/docs/examples/widgets/data_table_sort.py +++ b/docs/examples/widgets/data_table_sort.py @@ -7,7 +7,7 @@ ROWS = [ ("lane", "swimmer", "country", "time 1", "time 2"), (4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84), - (2, "Michael Phelps", Text("United States", style="italic"), 51.14, 51.84), + (2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84), (5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73), (6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58), (3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26), @@ -50,14 +50,16 @@ def action_sort_by_average_time(self) -> None: """Sort DataTable by average of times (via a function) and passing of column data through positional arguments.""" - def sort_by_average_time(times): - return sum(n for n in times) / len(times) + def sort_by_average_time_then_last_name(row_data): + name, *scores = row_data + return (sum(scores) / len(scores), name.split()[-1]) table = self.query_one(DataTable) table.sort( + "swimmer", "time 1", "time 2", - key=sort_by_average_time, + key=sort_by_average_time_then_last_name, reverse=self.sort_reverse("time"), ) @@ -65,7 +67,8 @@ def action_sort_by_last_name(self) -> None: """Sort DataTable by last name of swimmer (via a lambda).""" table = self.query_one(DataTable) table.sort( - key=lambda row: row[1]["swimmer"].split()[-1], + "swimmer", + key=lambda swimmer: swimmer.split()[-1], reverse=self.sort_reverse("swimmer"), ) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index fb46343e8d..37f583e251 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1328,3 +1328,40 @@ def custom_sort(nums): ) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row + + +async def test_sort_by_multiple_columns_and_function(): + """Test sorting a `DataTable` using a custom sort function and + only supplying data from specific columns.""" + + def custom_sort(row_data): + name, *scores = row_data + return (sum(scores) / len(scores), name.split()[-1]) + + row_data = ( + ["ID", "Student", "Participation", "Test 1", "Test 2", "Test 3"], + ["ID-01", "Joseph Schooling", True, 90, 91, 92], + ["ID-02", "Li Zhuhao", False, 92, 93, 94], + ["ID-03", "Chad le Clos", False, 92, 93, 94], + ["ID-04", "Michael Phelps", True, 95, 96, 99], + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + + for col in row_data[0]: + table.add_column(col, key=col) + table.add_rows(row_data[1:]) + + table.sort("Student", "Test 1", "Test 2", "Test 3", key=custom_sort) + sorted_row_data = (row_data[1], row_data[3], row_data[2], row_data[4]) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort( + "Student", "Test 1", "Test 2", "Test 3", key=custom_sort, reverse=True + ) + sorted_row_data = (row_data[4], row_data[2], row_data[3], row_data[1]) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row From 312f9acae5b0e5532be45cf57d7052078fea3f74 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Tue, 29 Aug 2023 23:47:35 -0400 Subject: [PATCH 04/10] removed unecessary code from example - the sort by clicked column function was bloat in my opinion --- docs/examples/widgets/data_table_sort.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py index 73669601be..d88223b9af 100644 --- a/docs/examples/widgets/data_table_sort.py +++ b/docs/examples/widgets/data_table_sort.py @@ -1,7 +1,6 @@ from rich.text import Text from textual.app import App, ComposeResult -from textual.events import Click from textual.widgets import DataTable, Footer ROWS = [ @@ -81,18 +80,6 @@ def action_sort_by_country(self) -> None: reverse=self.sort_reverse("country"), ) - def on_data_table_header_selected(self, event: Click) -> None: - """Sort `DataTable` items by the clicked column header.""" - - def sort_by_plain_text(cell): - return cell.plain if isinstance(cell, Text) else cell - - column_key = event.column_key - table = self.query_one(DataTable) - table.sort( - column_key, key=sort_by_plain_text, reverse=self.sort_reverse(column_key) - ) - app = TableApp() if __name__ == "__main__": From 5f71d0f3712dff0ec62f3cdd8ff20e7598581ef0 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Thu, 21 Sep 2023 11:29:03 -0400 Subject: [PATCH 05/10] requested changes --- CHANGELOG.md | 2 +- src/textual/widgets/_data_table.py | 11 ++--- tests/test_data_table.py | 69 +++++++----------------------- 3 files changed, 23 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb8806cad2..b0875aa41f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603 - Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757 -- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261 +- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/issues/2261 ## [0.32.0] - 2023-08-03 diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 1faed042f3..ebc819c943 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2108,7 +2108,7 @@ def _get_fixed_offset(self) -> Spacing: def sort( self, *columns: ColumnKey | str, - key: Callable | None = None, + key: Callable[[Any], Any] | None = None, reverse: bool = False, ) -> Self: """Sort the rows in the `DataTable` by one or more column keys or a @@ -2133,14 +2133,15 @@ def sort_by_column_keys( return result _key = key - if key and columns: + if key: - def _key(row): - return key(itemgetter(*columns)(row[1])) + def _key(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: + _, row_data = row + return key(itemgetter(*columns)(row_data)) ordered_rows = sorted( self._data.items(), - key=_key if _key else sort_by_column_keys, + key=_key if key is not None else sort_by_column_keys, reverse=reverse, ) self._row_locations = TwoWayDict( diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 37f583e251..1c943ee287 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1165,6 +1165,7 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse(): async def test_sort_by_multiple_columns(): """Test sorting a `DataTable` my multiple columns.""" + app = DataTableApp() async with app.run_test(): table = app.query_one(DataTable) @@ -1200,9 +1201,8 @@ async def test_sort_by_multiple_columns(): async def test_sort_by_function_sum(): """Test sorting a `DataTable` using a custom sort function.""" - def custom_sort(row): - _, row_data = row - return sum(row_data.values()) + def custom_sort(row_data): + return sum(row_data) row_data = ( [1, 3, 8], # SUM=12 @@ -1220,19 +1220,19 @@ def custom_sort(row): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(key=custom_sort) + table.sort(a, b, c, key=custom_sort) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(key=custom_sort, reverse=True) + table.sort(a, b, c, key=custom_sort, reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row -async def test_sort_by_function_sum_lambda(): - """Test sorting a `DataTable` using a custom sort lambda.""" +async def test_sort_by_lambda_function(): + """Test sorting a `DataTable` using lambda function.""" row_data = ( [1, 3, 8], # SUM=12 [2, 9, 5], # SUM=16 @@ -1249,12 +1249,12 @@ async def test_sort_by_function_sum_lambda(): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(key=lambda row: sum(row[1].values())) + table.sort(a, b, c, key=lambda row_data: sum(row_data)) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(key=lambda row: sum(row[1].values()), reverse=True) + table.sort(a, b, c, key=lambda row_data: sum(row_data), reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row @@ -1267,10 +1267,8 @@ async def test_sort_by_function_date(): Example: 26/07/23 == 2023-07-26 """ - def custom_sort(row): - _, row_data = row - result = itemgetter(birthdate)(row_data) - d, m, y = result.split("/") + def custom_sort(birthdate): + d, m, y = birthdate.split("/") return f"{y}{m}{d}" app = DataTableApp() @@ -1284,55 +1282,20 @@ def custom_sort(row): assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] - table.sort(key=custom_sort) + table.sort(birthdate, key=custom_sort) assert table.get_row_at(0) == ["Jane Doe", "25/12/22"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"] - table.sort(key=custom_sort, reverse=True) + table.sort(birthdate, key=custom_sort, reverse=True) assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] -async def test_sort_by_columns_and_function(): - """Test sorting a `DataTable` using a custom sort function and - only supplying data from specific columns.""" - - def custom_sort(nums): - return sum(n for n in nums) - - row_data = ( - ["A", "B", "C"], # Column headers - [1, 3, 8], # First and last columns SUM=9 - [2, 9, 5], # First and last columns SUM=7 - [1, 1, 9], # First and last columns SUM=10 - ) - - app = DataTableApp() - async with app.run_test(): - table = app.query_one(DataTable) - - for col in row_data[0]: - table.add_column(col, key=col) - table.add_rows(row_data[1:]) - - table.sort("A", "C", key=custom_sort) - sorted_row_data = sorted(row_data[1:], key=lambda row: row[0] + row[-1]) - for i, row in enumerate(sorted_row_data): - assert table.get_row_at(i) == row - - table.sort("A", "C", key=custom_sort, reverse=True) - sorted_row_data = sorted( - row_data[1:], key=lambda row: row[0] + row[-1], reverse=True - ) - for i, row in enumerate(sorted_row_data): - assert table.get_row_at(i) == row - - -async def test_sort_by_multiple_columns_and_function(): - """Test sorting a `DataTable` using a custom sort function and - only supplying data from specific columns.""" +async def test_sort_by_function_retuning_multiple_values(): + """Test sorting a `DataTable` using a custom sort function + that returns multiple for the row to be sorted by.""" def custom_sort(row_data): name, *scores = row_data From a0839ccc21cd01645bd6dbf03178df44681e8ea8 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Wed, 11 Oct 2023 23:27:18 -0400 Subject: [PATCH 06/10] simplify method and terminology --- src/textual/widgets/_data_table.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index ebc819c943..bad0298ed6 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2132,16 +2132,13 @@ def sort_by_column_keys( result = itemgetter(*columns)(row_data) return result - _key = key - if key: - - def _key(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: - _, row_data = row - return key(itemgetter(*columns)(row_data)) + def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: + _, row_data = row + return key(itemgetter(*columns)(row_data)) ordered_rows = sorted( self._data.items(), - key=_key if key is not None else sort_by_column_keys, + key=key_wrapper if key is not None else sort_by_column_keys, reverse=reverse, ) self._row_locations = TwoWayDict( From 948b81617ed46db126984331c7cb2e40592726c7 Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Tue, 24 Oct 2023 16:59:17 -0400 Subject: [PATCH 07/10] combine key_wrapper and default sort --- docs/examples/widgets/data_table_sort.py | 6 +++++ src/textual/widgets/_data_table.py | 17 ++++++------- tests/test_data_table.py | 31 ++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 11 deletions(-) diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py index d88223b9af..599a629394 100644 --- a/docs/examples/widgets/data_table_sort.py +++ b/docs/examples/widgets/data_table_sort.py @@ -22,6 +22,7 @@ class TableApp(App): ("a", "sort_by_average_time", "Sort By Average Time"), ("n", "sort_by_last_name", "Sort By Last Name"), ("c", "sort_by_country", "Sort By Country"), + ("d", "sort_by_columns", "Sort By Columns (Only)"), ] current_sorts: set = set() @@ -80,6 +81,11 @@ def action_sort_by_country(self) -> None: reverse=self.sort_reverse("country"), ) + def action_sort_by_columns(self) -> None: + """Sort DataTable without a key.""" + table = self.query_one(DataTable) + table.sort("swimmer", "lane", reverse=self.sort_reverse("columns")) + app = TableApp() if __name__ == "__main__": diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 2c20da5107..c90453d854 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2366,20 +2366,19 @@ def sort( The `DataTable` instance. """ - def sort_by_column_keys( - row: tuple[RowKey, dict[ColumnKey | str, CellType]] - ) -> Any: - _, row_data = row - result = itemgetter(*columns)(row_data) - return result - def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: _, row_data = row - return key(itemgetter(*columns)(row_data)) + if columns: + result = itemgetter(*columns)(row_data) + else: + result = tuple(row_data.values()) + if key is not None: + return key(result) + return result ordered_rows = sorted( self._data.items(), - key=key_wrapper if key is not None else sort_by_column_keys, + key=key_wrapper, reverse=reverse, ) self._row_locations = TwoWayDict( diff --git a/tests/test_data_table.py b/tests/test_data_table.py index cdfb705895..1d7d4b8113 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1199,8 +1199,33 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse(): assert not table._show_hover_cursor -async def test_sort_by_multiple_columns(): - """Test sorting a `DataTable` my multiple columns.""" +async def test_sort_by_all_columns_no_key(): + """Test sorting a `DataTable` by all columns.""" + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + a, b, c = table.add_columns("A", "B", "C") + table.add_row(1, 3, 8) + table.add_row(2, 9, 5) + table.add_row(1, 1, 9) + assert table.get_row_at(0) == [1, 3, 8] + assert table.get_row_at(1) == [2, 9, 5] + assert table.get_row_at(2) == [1, 1, 9] + + table.sort() + assert table.get_row_at(0) == [1, 1, 9] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [2, 9, 5] + + table.sort(reverse=True) + assert table.get_row_at(0) == [2, 9, 5] + assert table.get_row_at(1) == [1, 3, 8] + assert table.get_row_at(2) == [1, 1, 9] + + +async def test_sort_by_multiple_columns_no_key(): + """Test sorting a `DataTable` by multiple columns.""" app = DataTableApp() async with app.run_test(): @@ -1364,6 +1389,8 @@ def custom_sort(row_data): sorted_row_data = (row_data[4], row_data[2], row_data[3], row_data[1]) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row + + @pytest.mark.parametrize( ["cell", "height"], [ From 5db5c1af76ac8c65e16b503ff0cac3e644f8b9d8 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Fri, 27 Oct 2023 16:16:39 +0100 Subject: [PATCH 08/10] Removing some tests from DataTable.sort as duplicates. Ensure there is test coverage of the case where a key, but no columns, is passed to DataTable.sort. --- tests/test_data_table.py | 106 ++------------------------------------- 1 file changed, 4 insertions(+), 102 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 1d7d4b8113..3f201bd600 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1234,9 +1234,6 @@ async def test_sort_by_multiple_columns_no_key(): table.add_row(1, 3, 8) table.add_row(2, 9, 5) table.add_row(1, 1, 9) - assert table.get_row_at(0) == [1, 3, 8] - assert table.get_row_at(1) == [2, 9, 5] - assert table.get_row_at(2) == [1, 1, 9] table.sort(a, b, c) assert table.get_row_at(0) == [1, 1, 9] @@ -1275,122 +1272,27 @@ def custom_sort(row_data): async with app.run_test(): table = app.query_one(DataTable) a, b, c = table.add_columns("A", "B", "C") - for i, row in enumerate(row_data): table.add_row(*row) - assert table.get_row_at(i) == row - assert sum(table.get_row_at(i)) == sum(row) + # Sorting by all columns table.sort(a, b, c, key=custom_sort) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(a, b, c, key=custom_sort, reverse=True) - sorted_row_data = sorted(row_data, key=sum, reverse=True) - for i, row in enumerate(sorted_row_data): - assert table.get_row_at(i) == row - - -async def test_sort_by_lambda_function(): - """Test sorting a `DataTable` using lambda function.""" - row_data = ( - [1, 3, 8], # SUM=12 - [2, 9, 5], # SUM=16 - [1, 1, 9], # SUM=11 - ) - - app = DataTableApp() - async with app.run_test(): - table = app.query_one(DataTable) - a, b, c = table.add_columns("A", "B", "C") - - for i, row in enumerate(row_data): - table.add_row(*row) - assert table.get_row_at(i) == row - assert sum(table.get_row_at(i)) == sum(row) - - table.sort(a, b, c, key=lambda row_data: sum(row_data)) + # Passing a sort function but no columns also sorts by all columns + table.sort(key=custom_sort) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(a, b, c, key=lambda row_data: sum(row_data), reverse=True) + table.sort(a, b, c, key=custom_sort, reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row -async def test_sort_by_function_date(): - """Test sorting a `DataTable` by date when an abnormal date format is used. - Based on a question in the Discourse 'Help Wanted' section. - - Example: 26/07/23 == 2023-07-26 - """ - - def custom_sort(birthdate): - d, m, y = birthdate.split("/") - return f"{y}{m}{d}" - - app = DataTableApp() - async with app.run_test(): - table = app.query_one(DataTable) - _, birthdate = table.add_columns("Name", "Birthdate") - table.add_row("Doug Johnson", "26/07/23") - table.add_row("Albert Douglass", "16/07/23") - table.add_row("Jane Doe", "25/12/22") - assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] - assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] - assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] - - table.sort(birthdate, key=custom_sort) - assert table.get_row_at(0) == ["Jane Doe", "25/12/22"] - assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] - assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"] - - table.sort(birthdate, key=custom_sort, reverse=True) - assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] - assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] - assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] - - -async def test_sort_by_function_retuning_multiple_values(): - """Test sorting a `DataTable` using a custom sort function - that returns multiple for the row to be sorted by.""" - - def custom_sort(row_data): - name, *scores = row_data - return (sum(scores) / len(scores), name.split()[-1]) - - row_data = ( - ["ID", "Student", "Participation", "Test 1", "Test 2", "Test 3"], - ["ID-01", "Joseph Schooling", True, 90, 91, 92], - ["ID-02", "Li Zhuhao", False, 92, 93, 94], - ["ID-03", "Chad le Clos", False, 92, 93, 94], - ["ID-04", "Michael Phelps", True, 95, 96, 99], - ) - - app = DataTableApp() - async with app.run_test(): - table = app.query_one(DataTable) - - for col in row_data[0]: - table.add_column(col, key=col) - table.add_rows(row_data[1:]) - - table.sort("Student", "Test 1", "Test 2", "Test 3", key=custom_sort) - sorted_row_data = (row_data[1], row_data[3], row_data[2], row_data[4]) - for i, row in enumerate(sorted_row_data): - assert table.get_row_at(i) == row - - table.sort( - "Student", "Test 1", "Test 2", "Test 3", key=custom_sort, reverse=True - ) - sorted_row_data = (row_data[4], row_data[2], row_data[3], row_data[1]) - for i, row in enumerate(sorted_row_data): - assert table.get_row_at(i) == row - - @pytest.mark.parametrize( ["cell", "height"], [ From 9d17690fe13a4b86e6c3bfb7e8c36c1490d56fa5 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Fri, 27 Oct 2023 18:05:44 +0100 Subject: [PATCH 09/10] Remove unused import --- tests/test_data_table.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_data_table.py b/tests/test_data_table.py index 3f201bd600..4f473f978f 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -1,7 +1,5 @@ from __future__ import annotations -from operator import itemgetter - import pytest from rich.panel import Panel from rich.text import Text From 7541867e37f67f123b1c1f1df74719a88aca3452 Mon Sep 17 00:00:00 2001 From: Darren Burns Date: Fri, 27 Oct 2023 18:07:52 +0100 Subject: [PATCH 10/10] Fix merge issues in CHANGELOG, update DataTable sort-by-key changelog PR link --- CHANGELOG.md | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 128b124789..5baf8b8e7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,10 +25,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410 - Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525 +- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090 +- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566 +- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571 +- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498 -### Added - -- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/issues/2261 ### Changed @@ -50,15 +51,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Breaking change: empty rules now result in an error https://github.com/Textualize/textual/pull/3566 - Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575 -### Added - -- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566 -- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571 - -### Added - -- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498 - ## [0.40.0] - 2023-10-11 ### Added