DataTable sort by function (or other callable) #3090

Merged
merged 13 commits into from
Oct 31, 2023
Changes from 4 commits
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -16,8 +16,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- `MouseMove` events bubble up from widgets. `App` and `Screen` receive `MouseMove` events even if there's no Widget under the cursor. https://github.com/Textualize/textual/issues/2905

### Added
- Added an interface for replacing prompt of an individual option in an `OptionList` https://github.com/Textualize/textual/issues/2603
- Added `DirectoryTree.reload_node` method https://github.com/Textualize/textual/issues/2757
- Added `key` argument to the `DataTable.sort()` method, allowing the table to be sorted using a custom function (or other callable). Providing both `columns` and a `key` will limit the data sent to the key function to only those columns. https://github.com/Textualize/textual/issues/2261

## [0.32.0] - 2023-08-03

86 changes: 86 additions & 0 deletions docs/examples/widgets/data_table_sort.py
@@ -0,0 +1,86 @@
from rich.text import Text

from textual.app import App, ComposeResult
from textual.widgets import DataTable, Footer

ROWS = [
    ("lane", "swimmer", "country", "time 1", "time 2"),
    (4, "Joseph Schooling", Text("Singapore", style="italic"), 50.39, 51.84),
    (2, "Michael Phelps", Text("United States", style="italic"), 50.39, 51.84),
    (5, "Chad le Clos", Text("South Africa", style="italic"), 51.14, 51.73),
    (6, "László Cseh", Text("Hungary", style="italic"), 51.14, 51.58),
    (3, "Li Zhuhao", Text("China", style="italic"), 51.26, 51.26),
    (8, "Mehdy Metella", Text("France", style="italic"), 51.58, 52.15),
    (7, "Tom Shields", Text("United States", style="italic"), 51.73, 51.12),
    (1, "Aleksandr Sadovnikov", Text("Russia", style="italic"), 51.84, 50.85),
    (10, "Darren Burns", Text("Scotland", style="italic"), 51.84, 51.55),
]


class TableApp(App):
    BINDINGS = [
        ("a", "sort_by_average_time", "Sort By Average Time"),
        ("n", "sort_by_last_name", "Sort By Last Name"),
        ("c", "sort_by_country", "Sort By Country"),
    ]

    current_sorts: set = set()

    def compose(self) -> ComposeResult:
        yield DataTable()
        yield Footer()

    def on_mount(self) -> None:
        table = self.query_one(DataTable)
        for col in ROWS[0]:
            table.add_column(col, key=col)
        table.add_rows(ROWS[1:])

    def sort_reverse(self, sort_type: str):
        """Determine if `sort_type` is ascending or descending."""
        reverse = sort_type in self.current_sorts
        if reverse:
            self.current_sorts.remove(sort_type)
        else:
            self.current_sorts.add(sort_type)
        return reverse

    def action_sort_by_average_time(self) -> None:
        """Sort DataTable by average of times (via a function) and
        passing of column data through positional arguments."""

        def sort_by_average_time_then_last_name(row_data):
            name, *scores = row_data
            return (sum(scores) / len(scores), name.split()[-1])

        table = self.query_one(DataTable)
        table.sort(
            "swimmer",
            "time 1",
            "time 2",
            key=sort_by_average_time_then_last_name,
            reverse=self.sort_reverse("time"),
        )

    def action_sort_by_last_name(self) -> None:
        """Sort DataTable by last name of swimmer (via a lambda)."""
        table = self.query_one(DataTable)
        table.sort(
            "swimmer",
            key=lambda swimmer: swimmer.split()[-1],
            reverse=self.sort_reverse("swimmer"),
        )

    def action_sort_by_country(self) -> None:
        """Sort DataTable by country which is a `Rich.Text` object."""
        table = self.query_one(DataTable)
        table.sort(
            "country",
            key=lambda country: country.plain,
            reverse=self.sort_reverse("country"),
        )


app = TableApp()
if __name__ == "__main__":
    app.run()
21 changes: 16 additions & 5 deletions docs/widgets/data_table.md
@@ -143,11 +143,22 @@ visible as you scroll through the data table.

### Sorting

The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method.
In order to sort your data by a column, you must have supplied a `key` to the `add_column` method
when you added it.
You can then pass this key to the `sort` method to sort by that column.
Additionally, you can sort by multiple columns by passing multiple keys to `sort`.
The `DataTable` can be sorted using the [sort][textual.widgets.DataTable.sort] method. To sort your data by a column, you can provide the `key` you supplied to the `add_column` method or a `ColumnKey`. You can pass one or more column keys to the `sort` method to sort by one or more columns.

Additionally, you can sort your `DataTable` with a custom function (or other callable) via the `key` argument. Similar to the `key` parameter of the built-in [sorted()](https://docs.python.org/3/library/functions.html#sorted) function, your function (or other callable) should take a single argument (row) and return a key to use for sorting purposes.

Providing both `columns` and `key` will limit the row information sent to your `key` function (or other callable) to only the columns specified.

=== "Output"

```{.textual path="docs/examples/widgets/data_table_sort.py"}
```

=== "data_table_sort.py"

```python
--8<-- "docs/examples/widgets/data_table_sort.py"
```
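
As a rough illustration of the semantics described in this section, the sketch below mimics the behaviour with plain dictionaries and the built-in `sorted()`. It does not use Textual: the helper `sort_rows` and the sample rows are assumptions made for this example, not part of the `DataTable` API, and sending every cell to the key function when no columns are given is only one possible choice.

```python
from operator import itemgetter

# Hypothetical stand-in for table rows: each row maps a column key to a cell value.
rows = [
    {"swimmer": "Michael Phelps", "time 1": 50.39, "time 2": 51.84},
    {"swimmer": "Li Zhuhao", "time 1": 51.26, "time 2": 51.26},
    {"swimmer": "Tom Shields", "time 1": 51.73, "time 2": 51.12},
]


def sort_rows(rows, *columns, key=None, reverse=False):
    """Sort rows roughly the way the section describes (illustrative only)."""

    def sort_key(row):
        # With columns, only those cells are selected. itemgetter yields a bare
        # value for one column and a tuple for several columns.
        # Assumption: with no columns, every cell in the row is passed along.
        data = itemgetter(*columns)(row) if columns else tuple(row.values())
        return key(data) if key is not None else data

    return sorted(rows, key=sort_key, reverse=reverse)


# One column: the key callable receives the bare cell value.
by_last_name = sort_rows(rows, "swimmer", key=lambda name: name.split()[-1])

# Several columns: the key callable receives a tuple of the selected cells.
by_average = sort_rows(
    rows, "time 1", "time 2", key=lambda times: sum(times) / len(times)
)

print([row["swimmer"] for row in by_last_name])
print([row["swimmer"] for row in by_average])
```

The two calls differ only in how many columns are named, which is what changes the shape of the argument the key callable sees.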

### Labelled rows

21 changes: 17 additions & 4 deletions src/textual/widgets/_data_table.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass
from itertools import chain, zip_longest
from operator import itemgetter
from typing import Any, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast

import rich.repr
from rich.console import RenderableType
@@ -2108,12 +2108,17 @@ def _get_fixed_offset(self) -> Spacing:
    def sort(
        self,
        *columns: ColumnKey | str,
        key: Callable | None = None,
Collaborator

Can we narrow the typing for this Callable?

Contributor Author

Since key can accept and return such a wide range of types were you thinking something like key: Callable[..., Any] or key: Callable[[Any], Any]? Neither are very "narrowing" but seem to be pretty common practice from what I have seen elsewhere.

Collaborator

Doesn't the key function take a row, which is given as a list of values? In that case, wouldn't this be more precise?

Callable[list[Any], Any]

I was hoping there was some kind of typing.SupportsCompare, but there doesn't seem to be.

Contributor Author

Yes, but the args going to key won't be a list, as args come in as a tuple. And no matter the type, it seems the first type in the Callable typing has to be inside brackets unless it is the literal ellipsis (...). So, Callable[[Any], Any], or Callable[[list[Any]], Any], or Callable[[tuple[Any]], Any], and so on...

And to further complicate things, since itemgetter is used when columns are provided, either a single CellType value or a tuple of CellType values (when someone passes multiple columns to the sort function) gets passed to the function, which makes me lean towards Any.

Unless I am missing something (which is very possible), I think the options are Callable[[Any], Any] or Callable[[tuple[Any]], Any]. Thoughts?
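
To make the point about `itemgetter` concrete, here is a small sketch of the two argument shapes a key callable can end up receiving. The row dictionary, the `KeyFunction` alias and the helper `apply_key` are hypothetical names used only for illustration; the permissive `Callable[[Any], Any]` annotation is one of the options floated in this thread, not a settled decision.

```python
from operator import itemgetter
from typing import Any, Callable

# Hypothetical row data keyed by column, used only for this illustration.
row_data = {"swimmer": "Chad le Clos", "time 1": 51.14, "time 2": 51.73}

# One column name: itemgetter returns the bare cell value.
single = itemgetter("swimmer")(row_data)
# Several column names: itemgetter returns a tuple of cell values.
multiple = itemgetter("time 1", "time 2")(row_data)

# A permissive alias that covers both shapes of input.
KeyFunction = Callable[[Any], Any]


def apply_key(key: KeyFunction, data: Any) -> Any:
    """Call a user-supplied key callable on whatever itemgetter produced."""
    return key(data)


print(type(single).__name__, type(multiple).__name__)
print(apply_key(lambda name: name.split()[-1], single))
print(apply_key(lambda times: sum(times) / len(times), multiple))
```

Because the same parameter has to accept both a bare value and a tuple, any annotation narrower than `Any` on the input side would have to be expressed as a union or through overloads.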

Please forgive my ignorance if this is out of line, would an @overload be appropriate to help narrow the type? Roughly mocked, it could look like the following:

@overload
def sort(
    self,
    columns: ColumnKey | str,
    *,
    key: Callable[[CellType], Any] | None = None,
    reverse: bool = False,
) -> Self:
    ...

@overload
def sort(
    self,
    *columns: ColumnKey | str,
    key: Callable[[tuple[CellType, ...]], Any] | None = None,
    reverse: bool = False,
) -> Self:
    ...

def sort(
    self,
    *columns: ColumnKey | str,
    key: Callable[[Any], Any] | None = None,
    reverse: bool = False,
) -> Self:
    # Implementation

Though the type hints I'm receiving in test_data_table.py are appropriately enhanced and the new/updated test_sort_...() tests are still passing on my machine, there are two potential issues with my naive approach that I can see immediately:

  1. These overloads are just fancy lies and I don't think the relocation of the * in the first/singular @overload is good behavior. My type checker is yelling at me in the _data_table.py file on the implementation of sort() that Overloaded implementation is not consistent with signature of overload 1 due to the movement of the * between the singular overload and the implementation. This motion hides the fact that the singular columns: ColumnKey | str param is actually getting wrapped up into a tuple by the implementation's *columns: ColumnKey | str.
  2. The first @overload, in its singular form, may become the first completion suggestion and how a new user first interacts with sort(). This may lead them to the incorrect conclusion that, despite columns being plural and clear documentation, sort() is limited to a single column input.

Contributor Author

I honestly haven't used the @overload decorator much and that would probably be a directional decision for the maintainers.

As for the Callable typing... I have looked into it (too much) and there doesn't seem to be a general consensus. After lots of GitHub sleuthing, I have found most packages just go with Callable[[Any], Any] or Callable[..., Any] which is equivalent to just Callable. And I can't come up with anything better...

On the inputs/parameters side of the callable type, CellType is loosely defined so no matter if you use that, it just resolves to Any anyway.

And on the return side of the callable, since the method is called sort, I think it is implied that the return values should be "comparable", but that obviously leaves many options (which is probably best typed with Any).

So, all that to say, 🤷‍♂️...

        reverse: bool = False,
    ) -> Self:
        """Sort the rows in the `DataTable` by one or more column keys.
        """Sort the rows in the `DataTable` by one or more column keys or a
        key function (or other callable). If both columns and a key function
        are specified, only data from those columns will be sent to the key function.

        Args:
            columns: One or more columns to sort by the values in.
            key: A function (or other callable) that returns a key to
                use for sorting purposes.
            reverse: If True, the sort order will be reversed.

        Returns:
@@ -2127,11 +2132,19 @@ def sort_by_column_keys(
            result = itemgetter(*columns)(row_data)
            return result

        _key = key
        if key and columns:

            def _key(row):
Collaborator

So row is a tuple of row_key and row data?

Do we need the row key in there? I'm struggling to imagine a scenario where that might be necessary.

Contributor Author

I'm not sure I follow... The original code provided a tuple of tuple[RowKey, dict[ColumnKey | str, CellType]] to the default sort_by_column_keys() function, then after sorting, it uses the RowKey part of the now sorted tuples to update the _row_locations dictionary. I just followed that same logic when adding the catch for key and columns. I could re-write the catch to be more clear and in line with the default sort_by_column_keys() function like below.

if key and columns:

    def _key(row):
        _, row_data = row
        return key(itemgetter(*columns)(row_data))

Let me know if I am missing what you are pointing out?

Collaborator

I see how the row_key is used now. The key function gets a tuple of row key, and row data, but it discards the row key (which you unpack as _).

I think it would be surprising to have the row key there (and it's not documented in the signature). Could we allow the key function to accept just the row data? That's what I would expect to be the default.

I guess you would need to wrap the key function. Something like the following:

def key_wrapper(row):
    row_key, row_data = row
    return (row_key, key(row_data))

Contributor Author

Sorry, can I back up a bit and get your thoughts to make sure we are on the same page... I think we are working toward the same thing since I am already "wrapping" the key function to only send the row_data to the original user-supplied key function...

First things first. I should have asked this in the beginning, I guess. The _data_table.sort() API in main raises an exception if no columns are passed since the base sort_by_column_keys() function uses itemgetter to get each item from columns. When a key is provided should the same logic be followed, or should all row data be sent to the provided key function?

If the same exception should be raised if no columns are provided, then this "wraps" the user-supplied key function and only sends the actual row data.

_key = key
if key:

    def _key(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
        _, row_data = row
        return key(itemgetter(*columns)(row_data))

If all data should be sent in the absence of columns, then a ternary to send the correct data should do the trick.

_key = key
if key:

    def _key(row):
        _, row_data = row
        return key(itemgetter(*columns)(row_data) if columns else tuple(row_data.values()))

In both cases, only the actual row_data is being passed to the user-supplied key function; the only reason for the row_key tagging along is to be able to update the TwoWayDict the API has set up to track all of the data. To my knowledge, if you don't send the RowKey along to the built-in sorted function, you will have to do more work to reconstruct the self._row_locations dict after sorting.

ordered_rows = sorted(
    self._data.items(),
    key=_key if key is not None else sort_by_column_keys,
    reverse=reverse,
)
self._row_locations = TwoWayDict(
    {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
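
The wrapping idea discussed above can be tried outside Textual. The following is a minimal sketch under stated assumptions: a plain `dict` stands in for both the table's internal row storage and the `TwoWayDict`, and `sorted_locations` is a hypothetical helper, not a function from the pull request.

```python
from operator import itemgetter

# Hypothetical stand-in for the table's internal storage: row key -> row data.
data = {
    "row-a": {"swimmer": "Tom Shields", "lane": 7},
    "row-b": {"swimmer": "Li Zhuhao", "lane": 3},
    "row-c": {"swimmer": "Chad le Clos", "lane": 5},
}


def sorted_locations(data, *columns, key):
    """Return {row_key: new_index} after sorting by a user-supplied key."""

    def wrapped(item):
        # sorted() sees (row_key, row_data) pairs. The user's callable only
        # receives the selected cell values, never the row key.
        _, row_data = item
        return key(itemgetter(*columns)(row_data))

    ordered = sorted(data.items(), key=wrapped)
    # The row key travels with each pair, so the new positions fall out directly.
    return {row_key: index for index, (row_key, _) in enumerate(ordered)}


locations = sorted_locations(data, "swimmer", key=lambda name: name.split()[-1])
print(locations)
```

Sorting the `(row_key, row_data)` pairs rather than the bare row data is what lets the location mapping be rebuilt in a single pass afterwards.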

                return key(itemgetter(*columns)(row[1]))

        ordered_rows = sorted(
            self._data.items(), key=sort_by_column_keys, reverse=reverse
            self._data.items(),
            key=_key if _key else sort_by_column_keys,
Collaborator

Can we have an explicit check for is not None here? This covers the (admittedly unlikely) scenario of a callable object with a __bool__ that returns False.
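
The scenario in this comment is easy to reproduce in isolation. The sketch below uses a contrived class, `FalsyKey`, and a hypothetical `default_key` fallback (neither appears in the pull request) to show how a truthiness test and an identity test pick different callables.

```python
class FalsyKey:
    """A callable whose truth value is False (contrived, for illustration)."""

    def __call__(self, value):
        return -value

    def __bool__(self):
        return False


def default_key(value):
    """Hypothetical fallback, standing in for a default sort key."""
    return value


key = FalsyKey()

# Truthiness test: the falsy callable is silently ignored.
chosen_by_truthiness = key if key else default_key
# Identity test: only a genuinely missing key falls back to the default.
chosen_by_identity = key if key is not None else default_key

print(sorted([3, 1, 2], key=chosen_by_truthiness))  # [1, 2, 3]
print(sorted([3, 1, 2], key=chosen_by_identity))  # [3, 2, 1]
```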

            reverse=reverse,
        )
        self._row_locations = TwoWayDict(
            {key: new_index for new_index, (key, _) in enumerate(ordered_rows)}
            {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
        )
        self._update_count += 1
        self.refresh()