Skip to content

Commit

Permalink
argument change and functionaloty update
Browse files Browse the repository at this point in the history
Changed back to orinal `columns` argument and added a new `key` argument
which takes a function (or other callable). This allows the PR to NOT BE
a breaking change.
  • Loading branch information
joshbduncan committed Aug 28, 2023
1 parent fcce7e2 commit 69730d8
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 36 deletions.
4 changes: 1 addition & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added
- Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603
- Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757

### Changed
- The `DataTable` widget now takes the `by` argument instead of `columns`, allowing the table to also be sorted using a custom function (or other callable). This is a breaking change since it requires all calls to the `sort` method to include an iterable of key(s) (or a singular function/callable). https://github.com/Textualize/textual/issues/2261
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261

## [0.32.0] - 2023-08-03

Expand Down
34 changes: 26 additions & 8 deletions docs/examples/widgets/data_table_sort.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rich.text import Text

from textual.app import App, ComposeResult
from textual.events import Click
from textual.widgets import DataTable, Footer

ROWS = [
Expand Down Expand Up @@ -46,32 +47,49 @@ def sort_reverse(self, sort_type: str):
return reverse

def action_sort_by_average_time(self) -> None:
"""Sort DataTable by average of times (via a function)."""
"""Sort DataTable by average of times (via a function) and
passing of column data through positional arguments."""

def sort_by_average_time(row):
_, row_data = row
times = [row_data["time 1"], row_data["time 2"]]
return sum(times) / len(times)
def sort_by_average_time(times):
return sum(n for n in times) / len(times)

table = self.query_one(DataTable)
table.sort(sort_by_average_time, reverse=self.sort_reverse("time"))
table.sort(
"time 1",
"time 2",
key=sort_by_average_time,
reverse=self.sort_reverse("time"),
)

def action_sort_by_last_name(self) -> None:
"""Sort DataTable by last name of swimmer (via a lambda)."""
table = self.query_one(DataTable)
table.sort(
lambda row: row[1]["swimmer"].split()[-1],
key=lambda row: row[1]["swimmer"].split()[-1],
reverse=self.sort_reverse("swimmer"),
)

def action_sort_by_country(self) -> None:
"""Sort DataTable by country which is a `Rich.Text` object."""
table = self.query_one(DataTable)
table.sort(
lambda row: row[1]["country"].plain,
"country",
key=lambda country: country.plain,
reverse=self.sort_reverse("country"),
)

def on_data_table_header_selected(self, event: Click) -> None:
"""Sort `DataTable` items by the clicked column header."""

def sort_by_plain_text(cell):
return cell.plain if isinstance(cell, Text) else cell

column_key = event.column_key
table = self.query_one(DataTable)
table.sort(
column_key, key=sort_by_plain_text, reverse=self.sort_reverse(column_key)
)


app = TableApp()
if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions docs/widgets/data_table.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@ visible as you scroll through the data table.

### Sorting

The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass this key to the `sort` method inside of an iterable. To sort by multiple columns, pass an iterable with multiple keys to `sort`.
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. In order to sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can then pass one more column keys to the `sort` method to sort by one or more columns.

Additionally, you can sort your `DataTable` with a custom function (or other callable). Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and returns a key to use for sorting purposes.
Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.

Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.

=== "Output"

Expand Down
28 changes: 20 additions & 8 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,15 +2107,18 @@ def _get_fixed_offset(self) -> Spacing:

def sort(
self,
by: Iterable[ColumnKey | str] | Callable,
*columns: ColumnKey | str,
key: Callable | None = None,
reverse: bool = False,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys or a
key function (or other callable).
key function (or other callable). If both columns and a key function
are specified, only data from those columns will sent to the key function.
Args:
by: One or more columns to sort by the values by, or a function
(or other callable) that returns a key to use for sorting purposes.
columns: One or more columns to sort by the values in.
key: A function (or other callable) that returns a key to
use for sorting purposes.
reverse: If True, the sort order will be reversed.
Returns:
Expand All @@ -2126,13 +2129,22 @@ def sort_by_column_keys(
row: tuple[RowKey, dict[ColumnKey | str, CellType]]
) -> Any:
_, row_data = row
result = itemgetter(*by)(row_data)
result = itemgetter(*columns)(row_data)
return result

key = by if isinstance(by, Callable) else sort_by_column_keys
ordered_rows = sorted(self._data.items(), key=key, reverse=reverse)
_key = key
if key and columns:

def _key(row):
return key(itemgetter(*columns)(row[1]))

ordered_rows = sorted(
self._data.items(),
key=_key if _key else sort_by_column_keys,
reverse=reverse,
)
self._row_locations = TwoWayDict(
{key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
self._update_count += 1
self.refresh()
Expand Down
2 changes: 1 addition & 1 deletion tests/snapshot_tests/snapshot_apps/data_table_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def on_mount(self) -> None:

def action_sort(self):
table = self.query_one(DataTable)
table.sort(["time", "lane"])
table.sort("time", "lane")


app = TableApp()
Expand Down
63 changes: 49 additions & 14 deletions tests/test_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ async def test_get_row():
assert table.get_row(second_row) == [3, 2, 1]

# Even if row positions change, keys should always refer to same rows.
table.sort([b])
table.sort(b)
assert table.get_row(first_row) == [2, 4, 1]
assert table.get_row(second_row) == [3, 2, 1]

Expand All @@ -502,7 +502,7 @@ async def test_get_row_at():
assert table.get_row_at(1) == [3, 2, 1]

# If we sort, then the rows present at the indices *do* change!
table.sort([b])
table.sort(b)

# Since we sorted on column "B", the rows at indices 0 and 1 are swapped.
assert table.get_row_at(0) == [3, 2, 1]
Expand Down Expand Up @@ -878,7 +878,7 @@ async def test_sort_coordinate_and_key_access():
assert table.get_cell_at(Coordinate(1, 0)) == 1
assert table.get_cell_at(Coordinate(2, 0)) == 2

table.sort([column])
table.sort(column)

# The keys still refer to the same cells...
assert table.get_cell(row_one, column) == 1
Expand Down Expand Up @@ -911,7 +911,7 @@ async def test_sort_reverse_coordinate_and_key_access():
assert table.get_cell_at(Coordinate(1, 0)) == 1
assert table.get_cell_at(Coordinate(2, 0)) == 2

table.sort([column], reverse=True)
table.sort(column, reverse=True)

# The keys still refer to the same cells...
assert table.get_cell(row_one, column) == 1
Expand Down Expand Up @@ -1176,22 +1176,22 @@ async def test_sort_by_multiple_columns():
assert table.get_row_at(1) == [2, 9, 5]
assert table.get_row_at(2) == [1, 1, 9]

table.sort([a, b, c])
table.sort(a, b, c)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort([a, c, b])
table.sort(a, c, b)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]

table.sort([c, a, b], reverse=True)
table.sort(c, a, b, reverse=True)
assert table.get_row_at(0) == [1, 1, 9]
assert table.get_row_at(1) == [1, 3, 8]
assert table.get_row_at(2) == [2, 9, 5]

table.sort([a, c])
table.sort(a, c)
assert table.get_row_at(0) == [1, 3, 8]
assert table.get_row_at(1) == [1, 1, 9]
assert table.get_row_at(2) == [2, 9, 5]
Expand Down Expand Up @@ -1220,12 +1220,12 @@ def custom_sort(row):
assert table.get_row_at(i) == row
assert sum(table.get_row_at(i)) == sum(row)

table.sort(custom_sort)
table.sort(key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(custom_sort, reverse=True)
table.sort(key=custom_sort, reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
Expand All @@ -1249,12 +1249,12 @@ async def test_sort_by_function_sum_lambda():
assert table.get_row_at(i) == row
assert sum(table.get_row_at(i)) == sum(row)

table.sort(lambda row: sum(row[1].values()))
table.sort(key=lambda row: sum(row[1].values()))
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(lambda row: sum(row[1].values()), reverse=True)
table.sort(key=lambda row: sum(row[1].values()), reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
Expand Down Expand Up @@ -1284,12 +1284,47 @@ def custom_sort(row):
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Jane Doe", "25/12/22"]

table.sort(custom_sort)
table.sort(key=custom_sort)
assert table.get_row_at(0) == ["Jane Doe", "25/12/22"]
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"]

table.sort(custom_sort, reverse=True)
table.sort(key=custom_sort, reverse=True)
assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"]
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Jane Doe", "25/12/22"]


async def test_sort_by_columns_and_function():
"""Test sorting a `DataTable` using a custom sort function and
only supplying data from specific columns."""

def custom_sort(nums):
return sum(n for n in nums)

row_data = (
["A", "B", "C"], # Column headers
[1, 3, 8], # First and last columns SUM=9
[2, 9, 5], # First and last columns SUM=7
[1, 1, 9], # First and last columns SUM=10
)

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)

for col in row_data[0]:
table.add_column(col, key=col)
table.add_rows(row_data[1:])

table.sort("A", "C", key=custom_sort)
sorted_row_data = sorted(row_data[1:], key=lambda row: row[0] + row[-1])
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort("A", "C", key=custom_sort, reverse=True)
sorted_row_data = sorted(
row_data[1:], key=lambda row: row[0] + row[-1], reverse=True
)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

0 comments on commit 69730d8

Please sign in to comment.