Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visit() function to source tree #121

Merged
merged 2 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion codebasin/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import hashlib
import logging
import os
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy
from enum import Enum
from typing import Self

import numpy as np
Expand Down Expand Up @@ -539,6 +540,11 @@ class ParseError(ValueError):
"""


class Visit(Enum):
NEXT = 0
NEXT_SIBLING = 1


class Node:
"""
Base class for all other Node types.
Expand Down Expand Up @@ -597,6 +603,22 @@ def walk(self) -> Iterable[Self]:
for child in self.children:
yield from child.walk()

def visit(self, visitor: Callable[[Self], Visit]):
"""
Visit all descendants of this node via a preorder traversal, using the
supplied visitor.

Raises
------
TypeError
If `visitor` is not callable.
"""
if not callable(visitor):
raise TypeError("visitor is not callable.")
if visitor(self) != Visit.NEXT_SIBLING:
for child in self.children:
child.visit(visitor)


class FileNode(Node):
"""
Expand Down Expand Up @@ -2359,6 +2381,18 @@ def walk(self) -> Iterable[Node]:
"""
yield from self.root.walk()

def visit(self, visitor: Callable[[Node], Visit]):
"""
Visit each node in the tree via a preorder traversal, using the
supplied visitor.

Raises
------
TypeError
If `visitor` is not callable.
"""
self.root.visit(visitor)

def associate_file(self, filename):
self.root.filename = filename

Expand Down
123 changes: 89 additions & 34 deletions tests/source-tree/test_source_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from codebasin.file_parser import FileParser
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode, Visit


class TestSourceTree(unittest.TestCase):
Expand All @@ -19,9 +19,6 @@ def setUp(self):
logging.getLogger("codebasin").disabled = False
warnings.simplefilter("ignore", ResourceWarning)

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""

# TODO: Revisit this when SourceTree can be built without a file.
with tempfile.NamedTemporaryFile(
mode="w",
Expand All @@ -43,36 +40,94 @@ def test_walk(self):
f.close()

# TODO: Revisit this when __str__() is more reliable.
tree = FileParser(f.name).parse_file(summarize_only=False)
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
f.name,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)
self.tree = FileParser(f.name).parse_file(summarize_only=False)
self.filename = f.name

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
self.filename,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(self.tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)

def test_visit_types(self):
"""Check that visit() validates inputs"""

class valid_visitor:
def __call__(self, node):
return True

self.tree.visit(valid_visitor())

def visitor_function(node):
return True

self.tree.visit(visitor_function)

with self.assertRaises(TypeError):
self.tree.visit(1)

class invalid_visitor:
pass

with self.assertRaises(TypeError):
self.tree.visit(invalid_visitor())

def test_visit(self):
"""Check that visit() visits nodes as expected"""

# Check that a trivial visitor visits all nodes.
class NodeCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
self.count += 1

node_counter = NodeCounter()
self.tree.visit(node_counter)
self.assertEqual(node_counter.count, 9)

# Check that returning NEXT_SIBLING prevents descent.
class TopLevelCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
if not isinstance(node, FileNode):
self.count += 1
if isinstance(node, DirectiveNode):
return Visit.NEXT_SIBLING
return Visit.NEXT

top_level_counter = TopLevelCounter()
self.tree.visit(top_level_counter)
self.assertEqual(top_level_counter.count, 5)


if __name__ == "__main__":
Expand Down
Loading