From 69730d8ef648cd9067589c33c63b75b22a2b369d Mon Sep 17 00:00:00 2001 From: Josh Duncan <44387852+joshbduncan@users.noreply.github.com> Date: Mon, 28 Aug 2023 14:45:08 -0400 Subject: [PATCH] argument change and functionaloty update Changed back to orinal `columns` argument and added a new `key` argument which takes a function (or other callable). This allows the PR to NOT BE a breaking change. --- CHANGELOG.md | 4 +- docs/examples/widgets/data_table_sort.py | 34 +++++++--- docs/widgets/data_table.md | 6 +- src/textual/widgets/_data_table.py | 28 ++++++--- .../snapshot_apps/data_table_sort.py | 2 +- tests/test_data_table.py | 63 ++++++++++++++----- 6 files changed, 101 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b486cfad2..fb8806cad2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,9 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603 - Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757 - -### Changed -- The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). https://github.com/Textualize/textual/issues/2261 +- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261 ## [0.32.0] - 2023-08-03 diff --git a/docs/examples/widgets/data_table_sort.py b/docs/examples/widgets/data_table_sort.py index d6afea74fa..1844b78d9f 100644 --- a/docs/examples/widgets/data_table_sort.py +++ b/docs/examples/widgets/data_table_sort.py @@ -1,6 +1,7 @@ from rich.text import Text from textual.app import App, ComposeResult +from textual.events import Click from textual.widgets import DataTable, Footer ROWS = [ @@ -46,21 +47,25 @@ def sort_reverse(self, sort_type: str): return reverse def action_sort_by_average_time(self) -> None: - """Sort DataTable by average of times (via a function).""" + """Sort DataTable by average of times (via a function) and + passing of column data through positional arguments.""" - def sort_by_average_time(row): - _, row_data = row - times = [row_data["time 1"], row_data["time 2"]] - return sum(times) / len(times) + def sort_by_average_time(times): + return sum(n for n in times) / len(times) table = self.query_one(DataTable) - table.sort(sort_by_average_time, reverse=self.sort_reverse("time")) + table.sort( + "time 1", + "time 2", + key=sort_by_average_time, + reverse=self.sort_reverse("time"), + ) def action_sort_by_last_name(self) -> None: """Sort DataTable by last name of swimmer (via a lambda).""" table = self.query_one(DataTable) table.sort( - lambda row: row[1]["swimmer"].split()[-1], + key=lambda row: row[1]["swimmer"].split()[-1], reverse=self.sort_reverse("swimmer"), ) @@ -68,10 +73,23 @@ def action_sort_by_country(self) -> None: """Sort DataTable by country which is a `Rich.Text` object.""" table = self.query_one(DataTable) table.sort( - lambda row: row[1]["country"].plain, + "country", + key=lambda country: country.plain, reverse=self.sort_reverse("country"), ) + def on_data_table_header_selected(self, event: Click) -> None: + """Sort `DataTable` items by the clicked column header.""" + + def sort_by_plain_text(cell): + return cell.plain if isinstance(cell, Text) else cell + + column_key = event.column_key + table = self.query_one(DataTable) + table.sort( + column_key, key=sort_by_plain_text, reverse=self.sort_reverse(column_key) + ) + app = TableApp() if __name__ == "__main__": diff --git a/docs/widgets/data_table.md b/docs/widgets/data_table.md index 46af1bfca1..bd4919c733 100644 --- a/docs/widgets/data_table.md +++ b/docs/widgets/data_table.md @@ -143,9 +143,11 @@ visible as you scroll through the data table. ### Sorting -The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass this key to the `sort` method inside of an iterable. To sort by multiple columns, pass an iterable with multiple keys to `sort`. +The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns. -Additionally, you can sort your `DataTable` with a custom function (or other callable). Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and returns a key to use for sorting purposes. +Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes. + +Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified. === "Output" diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index c6ea2b4958..1faed042f3 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -2107,15 +2107,18 @@ def _get_fixed_offset(self) -> Spacing: def sort( self, - by: Iterable[ColumnKey | str] | Callable, + *columns: ColumnKey | str, + key: Callable | None = None, reverse: bool = False, ) -> Self: """Sort the rows in the `DataTable` by one or more column keys or a - key function (or other callable). + key function (or other callable). If both columns and a key function + are specified, only data from those columns will sent to the key function. Args: - by: One or more columns to sort by the values by, or a function - (or other callable) that returns a key to use for sorting purposes. + columns: One or more columns to sort by the values in. + key: A function (or other callable) that returns a key to + use for sorting purposes. reverse: If True, the sort order will be reversed. Returns: @@ -2126,13 +2129,22 @@ def sort_by_column_keys( row: tuple[RowKey, dict[ColumnKey | str, CellType]] ) -> Any: _, row_data = row - result = itemgetter(*by)(row_data) + result = itemgetter(*columns)(row_data) return result - key = by if isinstance(by, Callable) else sort_by_column_keys - ordered_rows = sorted(self._data.items(), key=key, reverse=reverse) + _key = key + if key and columns: + + def _key(row): + return key(itemgetter(*columns)(row[1])) + + ordered_rows = sorted( + self._data.items(), + key=_key if _key else sort_by_column_keys, + reverse=reverse, + ) self._row_locations = TwoWayDict( - {key: new_index for new_index, (key, _) in enumerate(ordered_rows)} + {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)} ) self._update_count += 1 self.refresh() diff --git a/tests/snapshot_tests/snapshot_apps/data_table_sort.py b/tests/snapshot_tests/snapshot_apps/data_table_sort.py index dacf8f3a8d..b4866e0ca5 100644 --- a/tests/snapshot_tests/snapshot_apps/data_table_sort.py +++ b/tests/snapshot_tests/snapshot_apps/data_table_sort.py @@ -36,7 +36,7 @@ def on_mount(self) -> None: def action_sort(self): table = self.query_one(DataTable) - table.sort(["time", "lane"]) + table.sort("time", "lane") app = TableApp() diff --git a/tests/test_data_table.py b/tests/test_data_table.py index a92a33b467..fb46343e8d 100644 --- a/tests/test_data_table.py +++ b/tests/test_data_table.py @@ -478,7 +478,7 @@ async def test_get_row(): assert table.get_row(second_row) == [3, 2, 1] # Even if row positions change, keys should always refer to same rows. - table.sort([b]) + table.sort(b) assert table.get_row(first_row) == [2, 4, 1] assert table.get_row(second_row) == [3, 2, 1] @@ -502,7 +502,7 @@ async def test_get_row_at(): assert table.get_row_at(1) == [3, 2, 1] # If we sort, then the rows present at the indices *do* change! - table.sort([b]) + table.sort(b) # Since we sorted on column "B", the rows at indices 0 and 1 are swapped. assert table.get_row_at(0) == [3, 2, 1] @@ -878,7 +878,7 @@ async def test_sort_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort([column]) + table.sort(column) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -911,7 +911,7 @@ async def test_sort_reverse_coordinate_and_key_access(): assert table.get_cell_at(Coordinate(1, 0)) == 1 assert table.get_cell_at(Coordinate(2, 0)) == 2 - table.sort([column], reverse=True) + table.sort(column, reverse=True) # The keys still refer to the same cells... assert table.get_cell(row_one, column) == 1 @@ -1176,22 +1176,22 @@ async def test_sort_by_multiple_columns(): assert table.get_row_at(1) == [2, 9, 5] assert table.get_row_at(2) == [1, 1, 9] - table.sort([a, b, c]) + table.sort(a, b, c) assert table.get_row_at(0) == [1, 1, 9] assert table.get_row_at(1) == [1, 3, 8] assert table.get_row_at(2) == [2, 9, 5] - table.sort([a, c, b]) + table.sort(a, c, b) assert table.get_row_at(0) == [1, 3, 8] assert table.get_row_at(1) == [1, 1, 9] assert table.get_row_at(2) == [2, 9, 5] - table.sort([c, a, b], reverse=True) + table.sort(c, a, b, reverse=True) assert table.get_row_at(0) == [1, 1, 9] assert table.get_row_at(1) == [1, 3, 8] assert table.get_row_at(2) == [2, 9, 5] - table.sort([a, c]) + table.sort(a, c) assert table.get_row_at(0) == [1, 3, 8] assert table.get_row_at(1) == [1, 1, 9] assert table.get_row_at(2) == [2, 9, 5] @@ -1220,12 +1220,12 @@ def custom_sort(row): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(custom_sort) + table.sort(key=custom_sort) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(custom_sort, reverse=True) + table.sort(key=custom_sort, reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row @@ -1249,12 +1249,12 @@ async def test_sort_by_function_sum_lambda(): assert table.get_row_at(i) == row assert sum(table.get_row_at(i)) == sum(row) - table.sort(lambda row: sum(row[1].values())) + table.sort(key=lambda row: sum(row[1].values())) sorted_row_data = sorted(row_data, key=sum) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row - table.sort(lambda row: sum(row[1].values()), reverse=True) + table.sort(key=lambda row: sum(row[1].values()), reverse=True) sorted_row_data = sorted(row_data, key=sum, reverse=True) for i, row in enumerate(sorted_row_data): assert table.get_row_at(i) == row @@ -1284,12 +1284,47 @@ def custom_sort(row): assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] - table.sort(custom_sort) + table.sort(key=custom_sort) assert table.get_row_at(0) == ["Jane Doe", "25/12/22"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"] - table.sort(custom_sort, reverse=True) + table.sort(key=custom_sort, reverse=True) assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"] assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"] assert table.get_row_at(2) == ["Jane Doe", "25/12/22"] + + +async def test_sort_by_columns_and_function(): + """Test sorting a `DataTable` using a custom sort function and + only supplying data from specific columns.""" + + def custom_sort(nums): + return sum(n for n in nums) + + row_data = ( + ["A", "B", "C"], # Column headers + [1, 3, 8], # First and last columns SUM=9 + [2, 9, 5], # First and last columns SUM=7 + [1, 1, 9], # First and last columns SUM=10 + ) + + app = DataTableApp() + async with app.run_test(): + table = app.query_one(DataTable) + + for col in row_data[0]: + table.add_column(col, key=col) + table.add_rows(row_data[1:]) + + table.sort("A", "C", key=custom_sort) + sorted_row_data = sorted(row_data[1:], key=lambda row: row[0] + row[-1]) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row + + table.sort("A", "C", key=custom_sort, reverse=True) + sorted_row_data = sorted( + row_data[1:], key=lambda row: row[0] + row[-1], reverse=True + ) + for i, row in enumerate(sorted_row_data): + assert table.get_row_at(i) == row