Skip to content

Commit

Permalink
requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
joshbduncan committed Sep 21, 2023
1 parent 312f9ac commit 5f71d0f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 59 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added
- Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603
- Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable) https://github.com/Textualize/textual/issues/2261

## [0.32.0] - 2023-08-03

Expand Down
11 changes: 6 additions & 5 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2108,7 +2108,7 @@ def _get_fixed_offset(self) -> Spacing:
def sort(
self,
*columns: ColumnKey | str,
key: Callable | None = None,
key: Callable[[Any], Any] | None = None,
reverse: bool = False,
) -> Self:
"""Sort the rows in the `DataTable` by one or more column keys or a
Expand All @@ -2133,14 +2133,15 @@ def sort_by_column_keys(
return result

_key = key
if key and columns:
if key:

def _key(row):
return key(itemgetter(*columns)(row[1]))
def _key(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
_, row_data = row
return key(itemgetter(*columns)(row_data))

ordered_rows = sorted(
self._data.items(),
key=_key if _key else sort_by_column_keys,
key=_key if key is not None else sort_by_column_keys,
reverse=reverse,
)
self._row_locations = TwoWayDict(
Expand Down
69 changes: 16 additions & 53 deletions tests/test_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,7 @@ async def test_unset_hover_highlight_when_no_table_cell_under_mouse():

async def test_sort_by_multiple_columns():
"""Test sorting a `DataTable` my multiple columns."""

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)
Expand Down Expand Up @@ -1200,9 +1201,8 @@ async def test_sort_by_multiple_columns():
async def test_sort_by_function_sum():
"""Test sorting a `DataTable` using a custom sort function."""

def custom_sort(row):
_, row_data = row
return sum(row_data.values())
def custom_sort(row_data):
return sum(row_data)

row_data = (
[1, 3, 8], # SUM=12
Expand All @@ -1220,19 +1220,19 @@ def custom_sort(row):
assert table.get_row_at(i) == row
assert sum(table.get_row_at(i)) == sum(row)

table.sort(key=custom_sort)
table.sort(a, b, c, key=custom_sort)
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(key=custom_sort, reverse=True)
table.sort(a, b, c, key=custom_sort, reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row


async def test_sort_by_function_sum_lambda():
"""Test sorting a `DataTable` using a custom sort lambda."""
async def test_sort_by_lambda_function():
"""Test sorting a `DataTable` using lambda function."""
row_data = (
[1, 3, 8], # SUM=12
[2, 9, 5], # SUM=16
Expand All @@ -1249,12 +1249,12 @@ async def test_sort_by_function_sum_lambda():
assert table.get_row_at(i) == row
assert sum(table.get_row_at(i)) == sum(row)

table.sort(key=lambda row: sum(row[1].values()))
table.sort(a, b, c, key=lambda row_data: sum(row_data))
sorted_row_data = sorted(row_data, key=sum)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort(key=lambda row: sum(row[1].values()), reverse=True)
table.sort(a, b, c, key=lambda row_data: sum(row_data), reverse=True)
sorted_row_data = sorted(row_data, key=sum, reverse=True)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row
Expand All @@ -1267,10 +1267,8 @@ async def test_sort_by_function_date():
Example: 26/07/23 == 2023-07-26
"""

def custom_sort(row):
_, row_data = row
result = itemgetter(birthdate)(row_data)
d, m, y = result.split("/")
def custom_sort(birthdate):
d, m, y = birthdate.split("/")
return f"{y}{m}{d}"

app = DataTableApp()
Expand All @@ -1284,55 +1282,20 @@ def custom_sort(row):
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Jane Doe", "25/12/22"]

table.sort(key=custom_sort)
table.sort(birthdate, key=custom_sort)
assert table.get_row_at(0) == ["Jane Doe", "25/12/22"]
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Doug Johnson", "26/07/23"]

table.sort(key=custom_sort, reverse=True)
table.sort(birthdate, key=custom_sort, reverse=True)
assert table.get_row_at(0) == ["Doug Johnson", "26/07/23"]
assert table.get_row_at(1) == ["Albert Douglass", "16/07/23"]
assert table.get_row_at(2) == ["Jane Doe", "25/12/22"]


async def test_sort_by_columns_and_function():
"""Test sorting a `DataTable` using a custom sort function and
only supplying data from specific columns."""

def custom_sort(nums):
return sum(n for n in nums)

row_data = (
["A", "B", "C"], # Column headers
[1, 3, 8], # First and last columns SUM=9
[2, 9, 5], # First and last columns SUM=7
[1, 1, 9], # First and last columns SUM=10
)

app = DataTableApp()
async with app.run_test():
table = app.query_one(DataTable)

for col in row_data[0]:
table.add_column(col, key=col)
table.add_rows(row_data[1:])

table.sort("A", "C", key=custom_sort)
sorted_row_data = sorted(row_data[1:], key=lambda row: row[0] + row[-1])
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row

table.sort("A", "C", key=custom_sort, reverse=True)
sorted_row_data = sorted(
row_data[1:], key=lambda row: row[0] + row[-1], reverse=True
)
for i, row in enumerate(sorted_row_data):
assert table.get_row_at(i) == row


async def test_sort_by_multiple_columns_and_function():
"""Test sorting a `DataTable` using a custom sort function and
only supplying data from specific columns."""
async def test_sort_by_function_retuning_multiple_values():
"""Test sorting a `DataTable` using a custom sort function
that returns multiple for the row to be sorted by."""

def custom_sort(row_data):
name, *scores = row_data
Expand Down

0 comments on commit 5f71d0f

Please sign in to comment.