Skip to content

Commit

Permalink
Merge pull request #3447 from Textualize/query-overloads-fix
Browse files Browse the repository at this point in the history
Improve typing for queries.
  • Loading branch information
rodrigogiraoserrao authored Oct 3, 2023
2 parents 02fe3bf + c9629b5 commit 03e3a69
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/textual/css/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class WrongType(QueryError):


QueryType = TypeVar("QueryType", bound="Widget")
"""Type variable used to type generic queries."""
ExpectType = TypeVar("ExpectType")
"""Type variable used to further restrict queries."""


@rich.repr.auto(angular=True)
Expand Down Expand Up @@ -187,10 +190,8 @@ def exclude(self, selector: str) -> DOMQuery[QueryType]:
"""
return DOMQuery(self.node, exclude=selector, parent=self)

ExpectType = TypeVar("ExpectType")

@overload
def first(self) -> Widget:
def first(self) -> QueryType:
...

@overload
Expand Down Expand Up @@ -226,7 +227,7 @@ def first(
raise NoMatches(f"No nodes match {self!r}")

@overload
def only_one(self) -> Widget:
def only_one(self) -> QueryType:
...

@overload
Expand All @@ -235,7 +236,7 @@ def only_one(self, expect_type: type[ExpectType]) -> ExpectType:

def only_one(
self, expect_type: type[ExpectType] | None = None
) -> Widget | ExpectType:
) -> QueryType | ExpectType:
"""Get the *only* matching node.
Args:
Expand All @@ -253,7 +254,9 @@ def only_one(
_rich_traceback_omit = True
# Call on first to get the first item. Here we'll use all of the
# testing and checking it provides.
the_one = self.first(expect_type) if expect_type is not None else self.first()
the_one: ExpectType | QueryType = (
self.first(expect_type) if expect_type is not None else self.first()
)
try:
# Now see if we can access a subsequent item in the nodes. There
# should *not* be anything there, so we *should* get an
Expand All @@ -268,10 +271,10 @@ def only_one(
# The IndexError was got, that's a good thing in this case. So
# we return what we found.
pass
return cast("Widget", the_one)
return the_one

@overload
def last(self) -> Widget:
def last(self) -> QueryType:
...

@overload
Expand Down Expand Up @@ -304,7 +307,7 @@ def last(
return last

@overload
def results(self) -> Iterator[Widget]:
def results(self) -> Iterator[QueryType]:
...

@overload
Expand All @@ -313,7 +316,7 @@ def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]:

def results(
self, filter_type: type[ExpectType] | None = None
) -> Iterator[Widget | ExpectType]:
) -> Iterator[QueryType | ExpectType]:
"""Get query results, optionally filtered by a given type.
Args:
Expand Down

0 comments on commit 03e3a69

Please sign in to comment.