diff --git a/src/textual/walk.py b/src/textual/walk.py index 0f6790c882..dcda856e49 100644 --- a/src/textual/walk.py +++ b/src/textual/walk.py @@ -50,29 +50,37 @@ def walk_depth_first( Args: root: The root note (starting point). - filter_type: Optional DOMNode subclass to filter by, or ``None`` for no filter. + filter_type: Optional DOMNode subclass to filter by, or `None` for no filter. with_root: Include the root in the walk. Returns: - An iterable of DOMNodes, or the type specified in ``filter_type``. + An iterable of DOMNodes, or the type specified in `filter_type`. """ - from textual.dom import DOMNode - stack: list[Iterator[DOMNode]] = [iter(root.children)] pop = stack.pop push = stack.append - check_type = filter_type or DOMNode - if with_root and isinstance(root, check_type): - yield root - while stack: - if (node := next(stack[-1], None)) is None: - pop() - else: - if isinstance(node, check_type): + if filter_type is None: + if with_root: + yield root + while stack: + if (node := next(stack[-1], None)) is None: + pop() + else: yield node - if children := node._nodes: - push(iter(children)) + if children := node._nodes: + push(iter(children)) + else: + if with_root and isinstance(root, filter_type): + yield root + while stack: + if (node := next(stack[-1], None)) is None: + pop() + else: + if isinstance(node, filter_type): + yield node + if children := node._nodes: + push(iter(children)) if TYPE_CHECKING: @@ -108,11 +116,11 @@ def walk_breadth_first( Args: root: The root note (starting point). - filter_type: Optional DOMNode subclass to filter by, or ``None`` for no filter. + filter_type: Optional DOMNode subclass to filter by, or `None` for no filter. with_root: Include the root in the walk. Returns: - An iterable of DOMNodes, or the type specified in ``filter_type``. + An iterable of DOMNodes, or the type specified in `filter_type`. """ from textual.dom import DOMNode