Skip to content

Commit

Permalink
Fixing some DataTable typing issues (this has propagated to result in…
Browse files Browse the repository at this point in the history
… more issues, but its a step in the right direction)
  • Loading branch information
darrenburns committed Feb 28, 2024
1 parent 2a8d6b6 commit ee088ab
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 34 deletions.
4 changes: 2 additions & 2 deletions src/textual/_two_way_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __delitem__(self, key: Key) -> None:
def __iter__(self):
return iter(self._forward)

def get(self, key: Key) -> Value:
def get(self, key: Key) -> Value | None:
"""Given a key, efficiently lookup and return the associated value.
Args:
Expand All @@ -43,7 +43,7 @@ def get(self, key: Key) -> Value:
"""
return self._forward.get(key)

def get_key(self, value: Value) -> Key:
def get_key(self, value: Value) -> Key | None:
"""Given a value, efficiently lookup and return the associated key.
Args:
Expand Down
100 changes: 68 additions & 32 deletions src/textual/widgets/_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from itertools import chain, zip_longest
from operator import itemgetter
from typing import Any, Callable, ClassVar, Generic, Iterable, NamedTuple, TypeVar, cast
from typing import Any, Callable, ClassVar, Iterable, NamedTuple, cast

import rich.repr
from rich.console import RenderableType
Expand Down Expand Up @@ -43,8 +43,6 @@
)
CursorType = Literal["cell", "row", "column", "none"]
"""The valid types of cursors for [`DataTable.cursor_type`][textual.widgets.DataTable.cursor_type]."""
CellType = TypeVar("CellType")
"""Type used for cells in the DataTable."""

_DEFAULT_CELL_X_PADDING = 1
"""Default padding to use on each side of a column in the data table."""
Expand Down Expand Up @@ -183,7 +181,7 @@ class Column:
content_width: int = 0
auto_width: bool = False

def get_render_width(self, data_table: DataTable[Any]) -> int:
def get_render_width(self, data_table: DataTable) -> int:
"""Width, in cells, required to render the column with padding included.
Args:
Expand Down Expand Up @@ -214,7 +212,7 @@ class RowRenderables(NamedTuple):
cells: list[RenderableType]


class DataTable(ScrollView, Generic[CellType], can_focus=True):
class DataTable(ScrollView, can_focus=True):
"""A tabular widget that contains data."""

BINDINGS: ClassVar[list[BindingType]] = [
Expand Down Expand Up @@ -355,13 +353,13 @@ class CellHighlighted(Message):
def __init__(
self,
data_table: DataTable,
value: CellType,
value: object,
coordinate: Coordinate,
cell_key: CellKey,
) -> None:
self.data_table = data_table
"""The data table."""
self.value: CellType = value
self.value: object = value
"""The value in the highlighted cell."""
self.coordinate: Coordinate = coordinate
"""The coordinate of the highlighted cell."""
Expand Down Expand Up @@ -390,13 +388,13 @@ class CellSelected(Message):
def __init__(
self,
data_table: DataTable,
value: CellType,
value: object,
coordinate: Coordinate,
cell_key: CellKey,
) -> None:
self.data_table = data_table
"""The data table."""
self.value: CellType = value
self.value: object = value
"""The value in the cell that was selected."""
self.coordinate: Coordinate = coordinate
"""The coordinate of the cell that was selected."""
Expand Down Expand Up @@ -641,7 +639,7 @@ def __init__(
"""

super().__init__(name=name, id=id, classes=classes, disabled=disabled)
self._data: dict[RowKey, dict[ColumnKey, CellType]] = {}
self._data: dict[RowKey, dict[ColumnKey, object]] = {}
"""Contains the cells of the table, indexed by row key and column key.
The final positioning of a cell on screen cannot be determined solely by this
structure. Instead, we must check _row_locations and _column_locations to find
Expand Down Expand Up @@ -776,7 +774,7 @@ def update_cell(
self,
row_key: RowKey | str,
column_key: ColumnKey | str,
value: CellType,
value: object,
*,
update_width: bool = False,
) -> None:
Expand Down Expand Up @@ -817,7 +815,7 @@ def update_cell(
self.refresh()

def update_cell_at(
self, coordinate: Coordinate, value: CellType, *, update_width: bool = False
self, coordinate: Coordinate, value: object, *, update_width: bool = False
) -> None:
"""Update the content inside the cell currently occupying the given coordinate.
Expand All @@ -833,7 +831,7 @@ def update_cell_at(
row_key, column_key = self.coordinate_to_cell_key(coordinate)
self.update_cell(row_key, column_key, value, update_width=update_width)

def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellType:
def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> object:
"""Given a row key and column key, return the value of the corresponding cell.
Args:
Expand All @@ -843,6 +841,8 @@ def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellTy
Returns:
The value of the cell identified by the row and column keys.
"""
row_key, column_key = self._ensure_keys(row_key, column_key)

try:
cell_value = self._data[row_key][column_key]
except KeyError:
Expand All @@ -851,7 +851,7 @@ def get_cell(self, row_key: RowKey | str, column_key: ColumnKey | str) -> CellTy
)
return cell_value

def get_cell_at(self, coordinate: Coordinate) -> CellType:
def get_cell_at(self, coordinate: Coordinate) -> object:
"""Get the value from the cell occupying the given coordinate.
Args:
Expand Down Expand Up @@ -881,6 +881,8 @@ def get_cell_coordinate(
Raises:
CellDoesNotExist: If the specified cell does not exist.
"""
row_key, column_key = self._ensure_keys(row_key, column_key)

if (
row_key not in self._row_locations
or column_key not in self._column_locations
Expand All @@ -892,7 +894,7 @@ def get_cell_coordinate(
column_index = self._column_locations.get(column_key)
return Coordinate(row_index, column_index)

def get_row(self, row_key: RowKey | str) -> list[CellType]:
def get_row(self, row_key: RowKey | str) -> list[object]:
"""Get the values from the row identified by the given row key.
Args:
Expand All @@ -904,15 +906,21 @@ def get_row(self, row_key: RowKey | str) -> list[CellType]:
Raises:
RowDoesNotExist: When there is no row corresponding to the key.
"""
if isinstance(row_key, str):
row_key = RowKey(row_key)

if row_key not in self._row_locations:
raise RowDoesNotExist(f"Row key {row_key!r} is not valid.")
cell_mapping: dict[ColumnKey, CellType] = self._data.get(row_key, {})
ordered_row: list[CellType] = [
cell_mapping[column.key] for column in self.ordered_columns
]
cell_mapping: dict[ColumnKey, object] | None = self._data.get(row_key)
if cell_mapping is not None:
ordered_row: list[object] = [
cell_mapping[column.key] for column in self.ordered_columns
]
else:
ordered_row = []
return ordered_row

def get_row_at(self, row_index: int) -> list[CellType]:
def get_row_at(self, row_index: int) -> list[object]:
"""Get the values from the cells in a row at a given index. This will
return the values from a row based on the rows _current position_ in
the table.
Expand Down Expand Up @@ -943,11 +951,15 @@ def get_row_index(self, row_key: RowKey | str) -> int:
Raises:
RowDoesNotExist: If the row key does not exist.
"""
if isinstance(row_key, str):
row_key = RowKey(row_key)

if row_key not in self._row_locations:
raise RowDoesNotExist(f"No row exists for row_key={row_key!r}")

return self._row_locations.get(row_key)

def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
def get_column(self, column_key: ColumnKey | str) -> Iterable[object]:
"""Get the values from the column identified by the given column key.
Args:
Expand All @@ -959,6 +971,9 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
Raises:
ColumnDoesNotExist: If there is no column corresponding to the key.
"""
if isinstance(column_key, str):
column_key = ColumnKey(column_key)

if column_key not in self._column_locations:
raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.")

Expand All @@ -967,7 +982,7 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]:
row_key = row_metadata.key
yield data[row_key][column_key]

def get_column_at(self, column_index: int) -> Iterable[CellType]:
def get_column_at(self, column_index: int) -> Iterable[object]:
"""Get the values from the column at a given index.
Args:
Expand Down Expand Up @@ -997,10 +1012,24 @@ def get_column_index(self, column_key: ColumnKey | str) -> int:
Raises:
ColumnDoesNotExist: If the column key does not exist.
"""
if isinstance(column_key, str):
column_key = ColumnKey(column_key)

if column_key not in self._column_locations:
raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}")

return self._column_locations.get(column_key)

def _ensure_keys(
self, row_key: RowKey | str, column_key: ColumnKey | str
) -> tuple[RowKey, ColumnKey]:
"""Convert row/column keys which may be strings into the dedicated RowKey/ColumnKey objects."""
if isinstance(row_key, str):
row_key = RowKey(row_key)
if isinstance(column_key, str):
column_key = ColumnKey(column_key)
return row_key, column_key

def _clear_caches(self) -> None:
self._row_render_cache.clear()
self._cell_render_cache.clear()
Expand Down Expand Up @@ -1175,9 +1204,15 @@ def coordinate_to_cell_key(self, coordinate: Coordinate) -> CellKey:
"""
if not self.is_valid_coordinate(coordinate):
raise CellDoesNotExist(f"No cell exists at {coordinate!r}.")

row_index, column_index = coordinate
row_key = self._row_locations.get_key(row_index)
column_key = self._column_locations.get_key(column_index)

# We've checked the coordinate is valid via is_valid_coordinate, so row_key and column_key should exist.
assert row_key is not None
assert column_key is not None

return CellKey(row_key, column_key)

def _highlight_row(self, row_index: int) -> None:
Expand Down Expand Up @@ -1478,7 +1513,7 @@ def add_column(
*,
width: int | None = None,
key: str | None = None,
default: CellType | None = None,
default: object = None,
) -> ColumnKey:
"""Add a column to the table.
Expand Down Expand Up @@ -1531,7 +1566,7 @@ def add_column(

def add_row(
self,
*cells: CellType,
*cells: object,
height: int | None = 1,
key: str | None = None,
label: TextType | None = None,
Expand Down Expand Up @@ -1604,13 +1639,13 @@ def add_columns(self, *labels: TextType) -> list[ColumnKey]:
the `add_column` method docstring for more information on how
these keys are used.
"""
column_keys = []
column_keys: list[ColumnKey] = []
for label in labels:
column_key = self.add_column(label, width=None)
column_keys.append(column_key)
return column_keys

def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]:
def add_rows(self, rows: Iterable[Iterable[object]]) -> list[RowKey]:
"""Add a number of rows at the bottom of the DataTable.
Args:
Expand All @@ -1621,7 +1656,7 @@ def add_rows(self, rows: Iterable[Iterable[CellType]]) -> list[RowKey]:
the `add_row` method docstring for more information on how
these keys are used.
"""
row_keys = []
row_keys: list[RowKey] = []
for row in rows:
row_key = self.add_row(*row)
row_keys.append(row_key)
Expand All @@ -1643,7 +1678,7 @@ def remove_row(self, row_key: RowKey | str) -> None:
self.check_idle()

index_to_delete = self._row_locations.get(row_key)
new_row_locations = TwoWayDict({})
new_row_locations: TwoWayDict[RowKey, int] = TwoWayDict({})
for row_location_key in self._row_locations:
row_index = self._row_locations.get(row_location_key)
if row_index > index_to_delete:
Expand Down Expand Up @@ -2384,7 +2419,7 @@ def sort(
The `DataTable` instance.
"""

def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
def key_wrapper(row: tuple[RowKey, dict[ColumnKey, object]]) -> Any:
_, row_data = row
if columns:
result = itemgetter(*columns)(row_data)
Expand All @@ -2394,12 +2429,13 @@ def key_wrapper(row: tuple[RowKey, dict[ColumnKey | str, CellType]]) -> Any:
return key(result)
return result

ordered_rows = sorted(
self._data.items(),
items: list[tuple[RowKey, dict[ColumnKey, object]]] = list(self._data.items())
ordered_rows: list[tuple[RowKey, dict[ColumnKey, object]]] = sorted(
items,
key=key_wrapper,
reverse=reverse,
)
self._row_locations = TwoWayDict(
self._row_locations: TwoWayDict[RowKey, int] = TwoWayDict(
{row_key: new_index for new_index, (row_key, _) in enumerate(ordered_rows)}
)
self._update_count += 1
Expand Down

0 comments on commit ee088ab

Please sign in to comment.