diff --git a/src/textual/css/query.py b/src/textual/css/query.py index ce966d6b18..4326bb7665 100644 --- a/src/textual/css/query.py +++ b/src/textual/css/query.py @@ -48,7 +48,10 @@ class WrongType(QueryError): """Query result was not of the correct type.""" -QueryType = TypeVar("QueryType", bound="Widget") +QueryType = TypeVar("QueryType", bound=Widget) +"""Type variable used to type generic queries.""" +ExpectType = TypeVar("ExpectType", bound=Widget) +"""Type variable used to further restrict queries.""" @rich.repr.auto(angular=True) @@ -187,10 +190,8 @@ def exclude(self, selector: str) -> DOMQuery[QueryType]: """ return DOMQuery(self.node, exclude=selector, parent=self) - ExpectType = TypeVar("ExpectType") - @overload - def first(self) -> Widget: + def first(self) -> QueryType: ... @overload @@ -226,7 +227,7 @@ def first( raise NoMatches(f"No nodes match {self!r}") @overload - def only_one(self) -> Widget: + def only_one(self) -> QueryType: ... @overload @@ -235,7 +236,7 @@ def only_one(self, expect_type: type[ExpectType]) -> ExpectType: def only_one( self, expect_type: type[ExpectType] | None = None - ) -> Widget | ExpectType: + ) -> QueryType | ExpectType: """Get the *only* matching node. Args: @@ -253,7 +254,9 @@ def only_one( _rich_traceback_omit = True # Call on first to get the first item. Here we'll use all of the # testing and checking it provides. - the_one = self.first(expect_type) if expect_type is not None else self.first() + the_one: ExpectType | QueryType = ( + self.first(expect_type) if expect_type is not None else self.first() + ) try: # Now see if we can access a subsequent item in the nodes. There # should *not* be anything there, so we *should* get an @@ -268,10 +271,10 @@ def only_one( # The IndexError was got, that's a good thing in this case. So # we return what we found. pass - return cast("Widget", the_one) + return the_one @overload - def last(self) -> Widget: + def last(self) -> QueryType: ... @overload @@ -304,7 +307,7 @@ def last( return last @overload - def results(self) -> Iterator[Widget]: + def results(self) -> Iterator[QueryType]: ... @overload @@ -313,7 +316,7 @@ def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]: def results( self, filter_type: type[ExpectType] | None = None - ) -> Iterator[Widget | ExpectType]: + ) -> Iterator[QueryType | ExpectType]: """Get query results, optionally filtered by a given type. Args: