diff --git a/src/textual/_two_way_dict.py b/src/textual/_two_way_dict.py index 123e1848d8..ac0ffe16ef 100644 --- a/src/textual/_two_way_dict.py +++ b/src/textual/_two_way_dict.py @@ -32,7 +32,7 @@ def __delitem__(self, key: Key) -> None: def __iter__(self): return iter(self._forward) - def get(self, key: Key) -> Value: + def get(self, key: Key) -> Value | None: """Given a key, efficiently lookup and return the associated value. Args: @@ -43,7 +43,7 @@ def get(self, key: Key) -> Value: """ return self._forward.get(key) - def get_key(self, value: Value) -> Key: + def get_key(self, value: Value) -> Key | None: """Given a value, efficiently lookup and return the associated key. Args: diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 0cd8cf293b..e1a65d7f8c 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from itertools import chain, zip_longest from operator import itemgetter -from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast +from typing import Any, Callable, ClassVar, Iterable, NamedTuple, cast import rich.repr from rich.console import RenderableType @@ -43,8 +43,6 @@ ) CursorType = Literal["cell", "row", "column", "none"] """The valid types of cursors for [`DataTable.cursor_type`][textual.widgets.DataTable.cursor_type].""" -CellType = TypeVar("CellType") -"""Type used for cells in the DataTable.""" _DEFAULT_CELL_X_PADDING = 1 """Default padding to use on each side of a column in the data table.""" @@ -183,7 +181,7 @@ class Column: content_width: int = 0 auto_width: bool = False - def get_render_width(self, data_table: DataTable[Any]) -> int: + def get_render_width(self, data_table: DataTable) -> int: """Width, in cells, required to render the column with padding included. Args: @@ -214,7 +212,7 @@ class RowRenderables(NamedTuple): cells: list[RenderableType] -class DataTable(ScrollView, Generic[CellType], can_focus=True): +class DataTable(ScrollView, can_focus=True): """A tabular widget that contains data.""" BINDINGS: ClassVar[list[BindingType]] = [ @@ -355,13 +353,13 @@ class CellHighlighted(Message): def __init__( self, data_table: DataTable, - value: CellType, + value: object, coordinate: Coordinate, cell_key: CellKey, ) -> None: self.data_table = data_table """The data table.""" - self.value: CellType = value + self.value: object = value """The value in the highlighted cell.""" self.coordinate: Coordinate = coordinate """The coordinate of the highlighted cell.""" @@ -390,13 +388,13 @@ class CellSelected(Message): def __init__( self, data_table: DataTable, - value: CellType, + value: object, coordinate: Coordinate, cell_key: CellKey, ) -> None: self.data_table = data_table """The data table.""" - self.value: CellType = value + self.value: object = value """The value in the cell that was selected.""" self.coordinate: Coordinate = coordinate """The coordinate of the cell that was selected.""" @@ -641,7 +639,7 @@ def __init__( """ super().__init__(name=name, id=id, classes=classes, disabled=disabled) - self._data: dict[RowKey, dict[ColumnKey, CellType]] = {} + self._data: dict[RowKey, dict[ColumnKey, object]] = {} """Contains the cells of the table, indexed by row key and column key. The final positioning of a cell on screen cannot be determined solely by this structure. Instead, we must check _row_locations and _column_locations to find @@ -776,7 +774,7 @@ def update_cell( self, row_key: RowKey | str, column_key: ColumnKey | str, - value: CellType, + value: object, *, update_width: bool = False, ) -> None: @@ -817,7 +815,7 @@ def update_cell( self.refresh() def update_cell_at( - self, coordinate: Coordinate, value: CellType, *, update_width: bool = False + self, coordinate: Coordinate, value: object, *, update_width: bool = False ) -> None: """Update the content inside the cell currently occupying the given coordinate. @@ -833,7 +831,7 @@ def update_cell_at( row_key, column_key = self.coordinate_to_cell_key(coordinate) self.update_cell(row_key, column_key, value, update_width=update_width) - def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellType: + def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> object: """Given a row key and column key, return the value of the corresponding cell. Args: @@ -843,6 +841,8 @@ def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellTy Returns: The value of the cell identified by the row and column keys. """ + row_key, column_key = self._ensure_keys(row_key, column_key) + try: cell_value = self._data[row_key][column_key] except KeyError: @@ -851,7 +851,7 @@ def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellTy ) return cell_value - def get_cell_at(self, coordinate: Coordinate) -> CellType: + def get_cell_at(self, coordinate: Coordinate) -> object: """Get the value from the cell occupying the given coordinate. Args: @@ -881,6 +881,8 @@ def get_cell_coordinate( Raises: CellDoesNotExist: If the specified cell does not exist. """ + row_key, column_key = self._ensure_keys(row_key, column_key) + if ( row_key not in self._row_locations or column_key not in self._column_locations @@ -892,7 +894,7 @@ def get_cell_coordinate( column_index = self._column_locations.get(column_key) return Coordinate(row_index, column_index) - def get_row(self, row_key: RowKey | str) -> list[CellType]: + def get_row(self, row_key: RowKey | str) -> list[object]: """Get the values from the row identified by the given row key. Args: @@ -904,15 +906,21 @@ def get_row(self, row_key: RowKey | str) -> list[CellType]: Raises: RowDoesNotExist: When there is no row corresponding to the key. """ + if isinstance(row_key, str): + row_key = RowKey(row_key) + if row_key not in self._row_locations: raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") - cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {}) - ordered_row: list[CellType] = [ - cell_mapping[column.key] for column in self.ordered_columns - ] + cell_mapping: dict[ColumnKey, object] | None = self._data.get(row_key) + if cell_mapping is not None: + ordered_row: list[object] = [ + cell_mapping[column.key] for column in self.ordered_columns + ] + else: + ordered_row = [] return ordered_row - def get_row_at(self, row_index: int) -> list[CellType]: + def get_row_at(self, row_index: int) -> list[object]: """Get the values from the cells in a row at a given index. This will return the values from a row based on the rows _current position_ in the table. @@ -943,11 +951,15 @@ def get_row_index(self, row_key: RowKey | str) -> int: Raises: RowDoesNotExist: If the row key does not exist. """ + if isinstance(row_key, str): + row_key = RowKey(row_key) + if row_key not in self._row_locations: raise RowDoesNotExist(f"No row exists for row_key={row_key!r}") + return self._row_locations.get(row_key) - def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: + def get_column(self, column_key: ColumnKey | str) -> Iterable[object]: """Get the values from the column identified by the given column key. Args: @@ -959,6 +971,9 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: Raises: ColumnDoesNotExist: If there is no column corresponding to the key. """ + if isinstance(column_key, str): + column_key = ColumnKey(column_key) + if column_key not in self._column_locations: raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") @@ -967,7 +982,7 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: row_key = row_metadata.key yield data[row_key][column_key] - def get_column_at(self, column_index: int) -> Iterable[CellType]: + def get_column_at(self, column_index: int) -> Iterable[object]: """Get the values from the column at a given index. Args: @@ -997,10 +1012,24 @@ def get_column_index(self, column_key: ColumnKey | str) -> int: Raises: ColumnDoesNotExist: If the column key does not exist. """ + if isinstance(column_key, str): + column_key = ColumnKey(column_key) + if column_key not in self._column_locations: raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}") + return self._column_locations.get(column_key) + def _ensure_keys( + self, row_key: RowKey | str, column_key: ColumnKey | str + ) -> tuple[RowKey, ColumnKey]: + """Convert row/column keys which may be strings into the dedicated RowKey/ColumnKey objects.""" + if isinstance(row_key, str): + row_key = RowKey(row_key) + if isinstance(column_key, str): + column_key = ColumnKey(column_key) + return row_key, column_key + def _clear_caches(self) -> None: self._row_render_cache.clear() self._cell_render_cache.clear() @@ -1175,9 +1204,15 @@ def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey: """ if not self.is_valid_coordinate(coordinate): raise CellDoesNotExist(f"No cell exists at {coordinate!r}.") + row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) column_key = self._column_locations.get_key(column_index) + + # We've checked the coordinate is valid via is_valid_coordinate, so row_key and column_key should exist. + assert row_key is not None + assert column_key is not None + return CellKey(row_key, column_key) def _highlight_row(self, row_index: int) -> None: @@ -1478,7 +1513,7 @@ def add_column( *, width: int | None = None, key: str | None = None, - default: CellType | None = None, + default: object = None, ) -> ColumnKey: """Add a column to the table. @@ -1531,7 +1566,7 @@ def add_column( def add_row( self, - *cells: CellType, + *cells: object, height: int | None = 1, key: str | None = None, label: TextType | None = None, @@ -1604,13 +1639,13 @@ def add_columns(self, *labels: TextType) -> list[ColumnKey]: the `add_column` method docstring for more information on how these keys are used. """ - column_keys = [] + column_keys: list[ColumnKey] = [] for label in labels: column_key = self.add_column(label, width=None) column_keys.append(column_key) return column_keys - def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]: + def add_rows(self, rows: Iterable[Iterable[object]]) -> list[RowKey]: """Add a number of rows at the bottom of the DataTable. Args: @@ -1621,7 +1656,7 @@ def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]: the `add_row` method docstring for more information on how these keys are used. """ - row_keys = [] + row_keys: list[RowKey] = [] for row in rows: row_key = self.add_row(*row) row_keys.append(row_key) @@ -1643,7 +1678,7 @@ def remove_row(self, row_key: RowKey | str) -> None: self.check_idle() index_to_delete = self._row_locations.get(row_key) - new_row_locations = TwoWayDict({}) + new_row_locations: TwoWayDict[RowKey, int] = TwoWayDict({}) for row_location_key in self._row_locations: row_index = self._row_locations.get(row_location_key) if row_index > index_to_delete: @@ -2384,7 +2419,7 @@ def sort( The `DataTable` instance. """ - def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: + def key_wrapper(row: tuple[RowKey, dict[ColumnKey, object]]) -> Any: _, row_data = row if columns: result = itemgetter(*columns)(row_data) @@ -2394,12 +2429,13 @@ def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any: return key(result) return result - ordered_rows = sorted( - self._data.items(), + items: list[tuple[RowKey, dict[ColumnKey, object]]] = list(self._data.items()) + ordered_rows: list[tuple[RowKey, dict[ColumnKey, object]]] = sorted( + items, key=key_wrapper, reverse=reverse, ) - self._row_locations = TwoWayDict( + self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict( {row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)} ) self._update_count += 1