Skip to content

Commit

Permalink
Improve typing for queries.
Browse files Browse the repository at this point in the history
ExpectType grew its Widget bound and then it was moved out of the body of the class so that it could be referenced inside methods, because it was needed inside the body of 'only_one'.
  • Loading branch information
rodrigogiraoserrao committed Oct 3, 2023
1 parent f4f83ee commit afd5ec1
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/textual/css/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ class WrongType(QueryError):
"""Query result was not of the correct type."""


QueryType = TypeVar("QueryType", bound="Widget")
QueryType = TypeVar("QueryType", bound=Widget)
"""Type variable used to type generic queries."""
ExpectType = TypeVar("ExpectType", bound=Widget)
"""Type variable used to further restrict queries."""


@rich.repr.auto(angular=True)
Expand Down Expand Up @@ -187,10 +190,8 @@ def exclude(self, selector: str) -> DOMQuery[QueryType]:
"""
return DOMQuery(self.node, exclude=selector, parent=self)

ExpectType = TypeVar("ExpectType")

@overload
def first(self) -> Widget:
def first(self) -> QueryType:
...

@overload
Expand Down Expand Up @@ -226,7 +227,7 @@ def first(
raise NoMatches(f"No nodes match {self!r}")

@overload
def only_one(self) -> Widget:
def only_one(self) -> QueryType:
...

@overload
Expand All @@ -235,7 +236,7 @@ def only_one(self, expect_type: type[ExpectType]) -> ExpectType:

def only_one(
self, expect_type: type[ExpectType] | None = None
) -> Widget | ExpectType:
) -> QueryType | ExpectType:
"""Get the *only* matching node.
Args:
Expand All @@ -253,7 +254,9 @@ def only_one(
_rich_traceback_omit = True
# Call on first to get the first item. Here we'll use all of the
# testing and checking it provides.
the_one = self.first(expect_type) if expect_type is not None else self.first()
the_one: ExpectType | QueryType = (
self.first(expect_type) if expect_type is not None else self.first()
)
try:
# Now see if we can access a subsequent item in the nodes. There
# should *not* be anything there, so we *should* get an
Expand All @@ -268,10 +271,10 @@ def only_one(
# The IndexError was got, that's a good thing in this case. So
# we return what we found.
pass
return cast("Widget", the_one)
return the_one

@overload
def last(self) -> Widget:
def last(self) -> QueryType:
...

@overload
Expand Down Expand Up @@ -304,7 +307,7 @@ def last(
return last

@overload
def results(self) -> Iterator[Widget]:
def results(self) -> Iterator[QueryType]:
...

@overload
Expand All @@ -313,7 +316,7 @@ def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]:

def results(
self, filter_type: type[ExpectType] | None = None
) -> Iterator[Widget | ExpectType]:
) -> Iterator[QueryType | ExpectType]:
"""Get query results, optionally filtered by a given type.
Args:
Expand Down

0 comments on commit afd5ec1

Please sign in to comment.