Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataTable sort by function (or other callable) #3090

Merged
merged 13 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- Add Document `get_index_from_location` / `get_location_from_index` https://github.com/Textualize/textual/pull/3410
- Add setter for `TextArea.text` https://github.com/Textualize/textual/discussions/3525
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/pull/3090
- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571
- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498


### Changed

Expand All @@ -46,15 +51,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Breaking change: empty rules now result in an error https://github.com/Textualize/textual/pull/3566
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575

### Added

- Added `initial` to all css rules, which restores default (i.e. value from DEFAULT_CSS) https://github.com/Textualize/textual/pull/3566
- Added HorizontalPad to pad.py https://github.com/Textualize/textual/pull/3571

### Added

- Added `AwaitComplete` class, to be used for optionally awaitable return values https://github.com/Textualize/textual/pull/3498

## [0.40.0] - 2023-10-11

### Added
Expand Down Expand Up @@ -248,7 +244,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

- DescendantBlur and DescendantFocus can now be used with @on decorator


## [0.32.0] - 2023-08-03

### Added
Expand Down
92 changes: 92 additions & 0 deletions docs/examples/widgets/data_table_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from rich.text import Text

from textual.app import App, ComposeResult
from textual.widgets import DataTable, Footer

ROWS = [
("lane", "swimmer", "country", "time 1", "time 2"),
(4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
(2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
(5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
(6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
(3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
(8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
(7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
(1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
(10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
]


class TableApp(App):
BINDINGS = [
("a", "sort_by_average_time", "Sort By Average Time"),
("n", "sort_by_last_name", "Sort By Last Name"),
("c", "sort_by_country", "Sort By Country"),
("d", "sort_by_columns", "Sort By Columns (Only)"),
]

current_sorts: set = set()

def compose(self) -> ComposeResult:
yield DataTable()
yield Footer()

def on_mount(self) -> None:
table = self.query_one(DataTable)
for col in ROWS[0]:
table.add_column(col, key=col)
table.add_rows(ROWS[1:])

def sort_reverse(self, sort_type: str):
"""Determine if `sort_type` is ascending or descending."""
reverse = sort_type in self.current_sorts
if reverse:
self.current_sorts.remove(sort_type)
else:
self.current_sorts.add(sort_type)
return reverse

def action_sort_by_average_time(self) -> None:
"""Sort DataTable by average of times (via a function) and
passing of column data through positional arguments."""

def sort_by_average_time_then_last_name(row_data):
name, *scores = row_data
return (sum(scores) / len(scores), name.split()[-1])

table = self.query_one(DataTable)
table.sort(
"swimmer",
"time 1",
"time 2",
key=sort_by_average_time_then_last_name,
reverse=self.sort_reverse("time"),
)

def action_sort_by_last_name(self) -> None:
"""Sort DataTable by last name of swimmer (via a lambda)."""
table = self.query_one(DataTable)
table.sort(
"swimmer",
key=lambda swimmer: swimmer.split()[-1],
reverse=self.sort_reverse("swimmer"),
)

def action_sort_by_country(self) -> None:
"""Sort DataTable by country which is a `Rich.Text` object."""
table = self.query_one(DataTable)
table.sort(
"country",
key=lambda country: country.plain,
reverse=self.sort_reverse("country"),
)

def action_sort_by_columns(self) -> None:
"""Sort DataTable without a key."""
table = self.query_one(DataTable)
table.sort("swimmer", "lane", reverse=self.sort_reverse("columns"))


app = TableApp()
if __name__ == "__main__":
app.run()
21 changes: 16 additions & 5 deletions docs/widgets/data_table.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,22 @@ visible as you scroll through the data table.

### Sorting

The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
when you added it.
You can then pass this key to the `sort` method to sort by that column.
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.

Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.

Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.

=== "Output"

```{.textual path="docs/examples/widgets/data_table_sort.py"}
```

=== "data_table_sort.py"

```python
--8<-- "docs/examples/widgets/data_table_sort.py"
```

### Labelled rows

Expand Down
26 changes: 18 additions & 8 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from itertools import chain, zip_longest
from operator import itemgetter
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast

import rich.repr
from rich.console import RenderableType
Expand Down Expand Up @@ -2348,30 +2348,40 @@ def _get_fixed_offset(self) -> Spacing:
def sort(
self,
*columns: ColumnKey | str,
key: Callable[[Any], Any] | None = None,
reverse: bool = False,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys.
"""Sort the rows in the `DataTable` by one or more column keys or a
key function (or other callable). If both columns and a key function
are specified, only data from those columns will sent to the key function.

Args:
columns: One or more columns to sort by the values in.
key: A function (or other callable) that returns a key to
use for sorting purposes.
reverse: If True, the sort order will be reversed.

Returns:
The `DataTable` instance.
"""

def sort_by_column_keys(
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
) -> Any:
def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
_, row_data = row
result = itemgetter(*columns)(row_data)
if columns:
result = itemgetter(*columns)(row_data)
else:
result = tuple(row_data.values())
if key is not None:
return key(result)
return result

ordered_rows = sorted(
self._data.items(), key=sort_by_column_keys, reverse=reverse
self._data.items(),
key=key_wrapper,
reverse=reverse,
)
self._row_locations = TwoWayDict(
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
self._update_count += 1
self.refresh()
Expand Down
94 changes: 94 additions & 0 deletions tests/test_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,100 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():
assert not table._show_hover_cursor


async def test_sort_by_all_columns_no_key():
"""Test sorting a `DataTable` by all columns."""

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [2, 9, 5]
assert table.get_row_at(2) == [1, 1, 9]

table.sort()
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(reverse=True)
assert table.get_row_at(0) == [2, 9, 5]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [1, 1, 9]


async def test_sort_by_multiple_columns_no_key():
"""Test sorting a `DataTable` by multiple columns."""

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
table.add_row(1, 3, 8)
table.add_row(2, 9, 5)
table.add_row(1, 1, 9)

table.sort(a, b, c)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(a, c, b)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(c, a, b, reverse=True)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort(a, c)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]


async def test_sort_by_function_sum():
"""Test sorting a `DataTable` using a custom sort function."""

def custom_sort(row_data):
return sum(row_data)

row_data = (
[1, 3, 8], # SUM=12
[2, 9, 5], # SUM=16
[1, 1, 9], # SUM=11
)

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
a, b, c = table.add_columns("A", "B", "C")
for i, row in enumerate(row_data):
table.add_row(*row)

# Sorting by all columns
table.sort(a, b, c, key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

# Passing a sort function but no columns also sorts by all columns
table.sort(key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(a, b, c, key=custom_sort, reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row


@pytest.mark.parametrize(
["cell", "height"],
[
Expand Down
Loading