Skip to content

Commit

Permalink
Merge pull request #27 from shunsvineyard/bug_fix
Browse files Browse the repository at this point in the history
Bug Fix
  • Loading branch information
Shun Huang authored Jul 4, 2021
2 parents 5ff0e6a + 61e7a25 commit c837718
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 51 deletions.
27 changes: 14 additions & 13 deletions forest/binary_trees/double_threaded_binary_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,33 +339,34 @@ def get_predecessor(node: Node) -> Optional[Node]:
return None

@staticmethod
def get_height(node: Node) -> int:
def get_height(node: Optional[Node]) -> int:
"""Get the height of the given subtree.
Parameters
----------
node: `Node`
node: `Optional[Node]`
The root of the subtree to get its height.
Returns
-------
`int`
The height of the given subtree. 0 if the subtree has only one node.
"""
if node.left_thread is False and node.right_thread is False:
return (
max(
DoubleThreadedBinaryTree.get_height(node.left), # type: ignore
DoubleThreadedBinaryTree.get_height(node.right), # type: ignore
if node:
if node.left_thread is False and node.right_thread is False:
return (
max(
DoubleThreadedBinaryTree.get_height(node.left),
DoubleThreadedBinaryTree.get_height(node.right),
)
+ 1
)
+ 1
)

if node.left_thread and node.right_thread is False:
DoubleThreadedBinaryTree.get_height(node.right) + 1 # type: ignore
if node.left_thread and node.right_thread is False:
return DoubleThreadedBinaryTree.get_height(node.right) + 1

if node.right_thread and node.left_thread is False:
DoubleThreadedBinaryTree.get_height(node.left) + 1 # type: ignore
if node.right_thread and node.left_thread is False:
return DoubleThreadedBinaryTree.get_height(node.left) + 1

return 0

Expand Down
55 changes: 28 additions & 27 deletions forest/binary_trees/single_threaded_binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,34 +316,34 @@ def get_predecessor(node: Node) -> Optional[Node]:
return parent

@staticmethod
def get_height(node: Node) -> int:
def get_height(node: Optional[Node]) -> int:
"""Get the height of the given subtree.
Parameters
----------
node: `Node`
node: `Optional[Node]`
The root of the subtree to get its height.
Returns
-------
`int`
The height of the given subtree. 0 if the subtree has only one node.
"""
if node.left and node.is_thread is False:
return (
max(
RightThreadedBinaryTree.get_height(node.left),
RightThreadedBinaryTree.get_height(node.right), # type: ignore
if node:
if node.left and node.is_thread is False:
return (
max(
RightThreadedBinaryTree.get_height(node.left),
RightThreadedBinaryTree.get_height(node.right),
)
+ 1
)
+ 1
)

if node.left:
return RightThreadedBinaryTree.get_height(node=node.left) + 1

if node.is_thread is False:
return RightThreadedBinaryTree.get_height(node=node.right) + 1 # type: ignore # noqa: E501
if node.left:
return RightThreadedBinaryTree.get_height(node=node.left) + 1

if node.is_thread is False:
return RightThreadedBinaryTree.get_height(node=node.right) + 1
return 0

def inorder_traverse(self) -> traversal.Pairs:
Expand Down Expand Up @@ -697,33 +697,34 @@ def get_predecessor(node: Node) -> Optional[Node]:
return None

@staticmethod
def get_height(node: Node) -> int:
def get_height(node: Optional[Node]) -> int:
"""Get the height of the given subtree.
Parameters
----------
node: `Node`
node: `Optional[Node]`
The root of the subtree to get its height.
Returns
-------
`int`
The height of the given subtree. 0 if the subtree has only one node.
"""
if node.right and node.is_thread is False:
return (
max(
LeftThreadedBinaryTree.get_height(node.left), # type: ignore
LeftThreadedBinaryTree.get_height(node.right),
if node:
if node.right and node.is_thread is False:
return (
max(
LeftThreadedBinaryTree.get_height(node.left),
LeftThreadedBinaryTree.get_height(node.right),
)
+ 1
)
+ 1
)

if node.right:
return LeftThreadedBinaryTree.get_height(node=node.right) + 1
if node.right:
return LeftThreadedBinaryTree.get_height(node=node.right) + 1

if node.is_thread is False:
return LeftThreadedBinaryTree.get_height(node=node.left) + 1 # type: ignore # noqa: E501
if node.is_thread is False:
return LeftThreadedBinaryTree.get_height(node=node.left) + 1

return 0

Expand Down
25 changes: 14 additions & 11 deletions forest/binary_trees/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ def _inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
stack = []
if root.right:
stack.append(root.right)
stack.append(root)

stack.append(root)
current = root.left

while True:
Expand All @@ -267,7 +266,7 @@ def _inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
current = current.left
continue
stack.append(current)
current = None
current = current.left

else: # current is None

Expand All @@ -282,7 +281,7 @@ def _inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
if len(stack) > 0:
if current.right == stack[-1]:
yield (current.key, current.data)
current = None
current = stack.pop() if len(stack) > 0 else None
continue
else: # current.right != stack[-1]:
# This case means there are more nodes on the right
Expand All @@ -307,8 +306,7 @@ def _reverse_inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
stack = []
if root.left:
stack.append(root.left)
stack.append(root)

stack.append(root)
current = root.right

while True:
Expand All @@ -320,7 +318,7 @@ def _reverse_inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
current = current.right
continue
stack.append(current)
current = None
current = current.right

else: # current is None

Expand All @@ -335,7 +333,7 @@ def _reverse_inorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
if len(stack) > 0:
if current.left == stack[-1]:
yield (current.key, current.data)
current = None
current = stack.pop() if len(stack) > 0 else None
continue
else: # current.right != stack[-1]:
# This case means there are more nodes on the right
Expand Down Expand Up @@ -385,7 +383,6 @@ def _postorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
stack = []
if root.right:
stack.append(root.right)

stack.append(root)
current = root.left

Expand All @@ -398,8 +395,12 @@ def _postorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
current = current.left
continue
else: # current.right is None
yield (current.key, current.data)
current = None
if current.left:
stack.append(current)
else:
yield (current.key, current.data)

current = current.left

else: # current is None
if len(stack) > 0:
Expand All @@ -421,3 +422,5 @@ def _postorder_traverse_non_recursive(root: SupportedNode) -> Pairs:
else: # stack is empty
yield (current.key, current.data)
break
else: # stack is empty
break
39 changes: 39 additions & 0 deletions tests/test_traversal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for the traversal module."""

import random

from forest.binary_trees import binary_search_tree
from forest.binary_trees import traversal

Expand Down Expand Up @@ -122,3 +124,40 @@ def test_binary_search_tree_traversal(basic_tree):
(4, "4"),
(1, "1"),
]


def test_binary_search_tree_traversal_random():
"""Test binary search tree traversal with random sampling."""
for _ in range(0, 10):

insert_data = random.sample(range(1, 2000), 1000)

tree = binary_search_tree.BinarySearchTree()
for key in insert_data:
tree.insert(key=key, data=str(key))

preorder_recursive = [item for item in traversal.preorder_traverse(tree, True)]
preorder = [item for item in traversal.preorder_traverse(tree, False)]
assert preorder_recursive == preorder

inorder_recursive = [item for item in traversal.inorder_traverse(tree, True)]
inorder_nonrecursive = [
item for item in traversal.inorder_traverse(tree, False)
]
assert inorder_recursive == inorder_nonrecursive

rinorder_recursive = [
item for item in traversal.reverse_inorder_traverse(tree, True)
]
rinorder_nonrecursive = [
item for item in traversal.reverse_inorder_traverse(tree, False)
]
assert rinorder_recursive == rinorder_nonrecursive

postorder_recursive = [
item for item in traversal.postorder_traverse(tree, True)
]
postorder_nonrecursive = [
item for item in traversal.postorder_traverse(tree, False)
]
assert postorder_recursive == postorder_nonrecursive

0 comments on commit c837718

Please sign in to comment.