Skip to content

Commit

Permalink
day23: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-ong committed Dec 24, 2023
1 parent bfc018a commit 47a59de
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 92 deletions.
63 changes: 48 additions & 15 deletions day23/lib/classes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from dataclasses import dataclass
from queue import Queue
from typing import Optional
Expand Down Expand Up @@ -78,6 +79,14 @@ def __len__(self) -> int:
"""length of path. Has to be -1 because problem is like that"""
return len(self.route) - 1

def __eq__(self, other: object) -> bool:
if not isinstance(other, Path):
return False
return self.route == other.route

def __hash__(self) -> int:
return hash(",".join(str(item) for item in self.route))


class Maze:
grid: list[list[str]] # 2d array of chars
Expand Down Expand Up @@ -117,6 +126,21 @@ def is_oob(self, position: Position) -> bool:
or position.col >= self.num_cols
)

def copy(self) -> "Maze":
result = Maze(deepcopy(self.grid))
return result

def get_cell_branches(self, position: Position) -> int:
"""returns how many branches come out of this tile"""
result = 0
if self[position] != ".":
return 0
for direction in position.expand():
tile = self[direction]
if tile is not None and tile != "#":
result += 1
return result


class Solver1:
maze: Maze
Expand Down Expand Up @@ -179,19 +203,28 @@ def expand_path(self, path: Path) -> list[Path]:
and expansion_tile != "#"
):
valid_expansions.append(expansion)
return generate_paths(path, valid_expansions)


def generate_paths(path: Path, expansions: list[Position]) -> list[Path]:
"""
Given a path and valid expansions, (optionally) copies the path.
Returns a list of new paths. If there is only one expansion, modifies it
in-place
"""

if len(expansions) == 0:
return []
elif len(expansions) == 1:
path.add(expansions[0])
return [path]
else:
result = []
for expansion in expansions[1:]:
new_path = path.copy()
new_path.add(expansion)
result.append(new_path)
path.add(expansions[0])
result.append(path)

if len(valid_expansions) == 0:
return []
elif len(valid_expansions) == 1:
path.add(valid_expansions[0])
return [path]
else:
result = []
for expansion in valid_expansions[1:]:
new_path = path.copy()
new_path.add(expansion)
result.append(new_path)
path.add(valid_expansions[0])
result.append(path)

return result
return result
134 changes: 63 additions & 71 deletions day23/lib/classes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import colorama

from day23.lib import classes
from day23.lib.classes import Maze, Path, Position

colorama.init(convert=True)
Expand Down Expand Up @@ -45,45 +46,47 @@ def __len__(self) -> int:


class Solver2:
maze: Maze
input_maze: Maze

def __init__(self, maze: Maze) -> None:
self.maze = maze

def get_cell_branches(self, position: Position) -> int:
result = 0
if self.maze[position] != ".":
return 0
for direction in position.expand():
tile = self.maze[direction]
if tile is not None and tile != "#":
result += 1
return result

def get_nodes(self) -> dict[Position, Node]:
self.input_maze = maze

@staticmethod
def get_nodes(maze: Maze) -> dict[Position, Node]:
"""
Gets does and marks them on the given maze
Note that the maze is modified in-place!
Nodes are *not* populated with edges
"""
nodes: list[Node] = []

start = Position(0, 1)
nodes.append(Node(0, start))
name = 1
for row in range(self.maze.num_rows):
for col in range(self.maze.num_cols):
for row in range(maze.num_rows):
for col in range(maze.num_cols):
pos = Position(row, col)
if self.get_cell_branches(pos) > 2:
if maze.get_cell_branches(pos) > 2:
node = Node(name, pos)
name += 1
nodes.append(node)

# add start and end coz they are dumb

end = Position(self.maze.num_rows - 1, self.maze.num_cols - 2)

end = Position(maze.num_rows - 1, maze.num_cols - 2)
nodes.append(Node(name, end))

for node in nodes:
self.maze[node.position] = colorama.Back.GREEN + "X" + colorama.Back.BLACK
maze[node.position] = colorama.Back.GREEN + "X" + colorama.Back.BLACK
return {node.position: node for node in nodes}

def fill_node(self, start_node: Node, nodes: dict[Position, Node]) -> None:
@staticmethod
def calculate_edges(
start_node: Node, nodes: dict[Position, Node], maze: Maze
) -> None:
"""
Given a start Node and maze, modifies the maze inplace, filling it in with #
Modifies the node and its connecting nodes by adding Edges
"""
first_path = Path()
first_path.add(start_node.position)
paths: Queue[Path] = Queue()
Expand All @@ -98,47 +101,36 @@ def fill_node(self, start_node: Node, nodes: dict[Position, Node]) -> None:
end_node = nodes[pos]
end_node.edges.append(edge.flip())
continue
expansions = self.expand_path(path)
expansions = Solver2.expand_path(path, maze)
for path in expansions:
paths.put(path)

def expand_path(self, path: Path) -> list[Path]:
@staticmethod
def expand_path(path: Path, maze: Maze) -> list[Path]:
"""Expands a path, nuking that section of the maze using #"""
current_pos: Position = path.last()
expansions = current_pos.expand()

valid_expansions = []
for expansion in expansions:
expansion_tile = self.maze[expansion]
expansion_tile = maze[expansion]
if (
path.can_add(expansion)
and expansion_tile is not None
and expansion_tile != "#"
):
valid_expansions.append(expansion)
if expansion_tile == ".":
self.maze[expansion] = "#"

if len(valid_expansions) == 0:
return []
elif len(valid_expansions) == 1:
path.add(valid_expansions[0])
return [path]
else:
result = []
for expansion in valid_expansions[1:]:
new_path = path.copy()
new_path.add(expansion)
result.append(new_path)
path.add(valid_expansions[0])
result.append(path)

return result
maze[expansion] = "#"
return classes.generate_paths(path, valid_expansions)

def build_nodes(self) -> list[Node]:
nodes: dict[Position, Node] = self.get_nodes()
print(self.maze)
# make backup of maze
maze_copy = self.input_maze.copy()
nodes: dict[Position, Node] = self.get_nodes(maze_copy)
print(maze_copy)
for node in nodes.values():
self.fill_node(node, nodes)
self.calculate_edges(node, nodes, maze_copy)

return list(nodes.values())

Expand All @@ -149,34 +141,34 @@ def solve(self) -> int:
start = time.time()
cpu_count = os.cpu_count() or 2
levels = int(math.log(cpu_count, 2))
result = solve3(nodes, 0, len(nodes) - 1, 0, set(), levels)
print(time.time() - start)
result = solve2(nodes, 0, len(nodes) - 1, 0, set(), levels)
print(f"Executed in: {time.time() - start}")
return result


def solve3(
def solve2(
nodes: list[Node],
current: int,
destination: int,
distance: int,
seen: set[int],
forks_remaining: int,
) -> int:
"""Solves a dfs by creating forking into multiprocessing"""
if current == destination:
return distance

best = 0
seen.add(current)

# Check if the current process is the main process

# run the code in this thread
if forks_remaining == 0 or len(nodes[current].edges) == 1:
for edge in nodes[current].edges:
neighbor, weight = edge.node2, edge.length
if neighbor in seen:
continue

result = solve3(
result = solve2(
nodes,
neighbor,
destination,
Expand All @@ -186,30 +178,30 @@ def solve3(
)
best = max(best, result)
else: # Use multiprocessing.Pool for parallel execution
with Pool(len(nodes[current].edges)) as pool:
tasks = []
for edge in nodes[current].edges:
neighbor, weight = edge.node2, edge.length
if neighbor in seen:
continue
tasks.append(
[
nodes,
neighbor,
destination,
distance + weight,
seen.copy(),
forks_remaining - 1,
]
)
for result in pool.map(solve3_helper, tasks):
tasks = []
for edge in nodes[current].edges:
neighbor, weight = edge.node2, edge.length
if neighbor in seen:
continue
tasks.append(
[
nodes,
neighbor,
destination,
distance + weight,
seen,
forks_remaining - 1,
]
)
with Pool(len(tasks)) as pool:
for result in pool.map(solve2_helper, tasks):
best = max(best, result)

seen.remove(current)

return best


# ThreadPoolExecutor doesnt have starmap so we use a helper
def solve3_helper(args: list[Any]) -> int:
return solve3(*args)
def solve2_helper(args: list[Any]) -> int:
"""ThreadPoolExecutor doesnt have starmap so we use a helper"""
return solve2(*args)
67 changes: 66 additions & 1 deletion day23/tests/test_classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from day23.day23 import INPUT_SMALL
from day23.lib.classes import Maze, Path, Position
from day23.lib.classes import Maze, Path, Position, Solver1, generate_paths
from day23.lib.parsers import get_maze


Expand All @@ -19,6 +19,37 @@ def test_maze() -> None:
assert maze[Position(0, 1)] == "."
assert maze[Position(-1, 0)] is None

position_checks = [
(Position(0, 1), 1),
(Position(1, 1), 2),
(Position(3, 11), 3),
(Position(5, 3), 3),
]

for pos, result in position_checks:
print(pos, result)
assert maze.get_cell_branches(pos) == result


def test_solver1() -> None:
maze: Maze = get_maze(INPUT_SMALL)
solver: Solver1 = Solver1(maze)

expands = [
(
Position(5, 5),
" ",
{Position(4, 5), Position(6, 5), Position(5, 4), Position(5, 6)},
),
(Position(5, 5), "^", {Position(4, 5)}),
(Position(5, 5), ">", {Position(5, 6)}),
(Position(5, 5), "<", {Position(5, 4)}),
(Position(5, 5), "v", {Position(6, 5)}),
]

for pos, tile, result in expands:
assert set(solver.expand_hill(pos, tile)) == result


def test_path() -> None:
path = Path()
Expand All @@ -32,3 +63,37 @@ def test_path() -> None:
path2.add(Position(0, 2))
assert path.last() == Position(0, 1)
assert path2.last() == Position(0, 2)


def test_generate_paths() -> None:
path = Path()
path.add(Position(0, 0))
path.add(Position(0, 1))

# manually generate the paths
path1 = path.copy()
path1.add(Position(0, 2))
path2 = path.copy()
path2.add(Position(-1, 1))

# note that the original path is modified in-place!
paths: list[Path] = generate_paths(path, [Position(0, 2), Position(-1, 1)])
assert len(paths) == 2

auto_path1 = paths[0]
auto_path2 = paths[1]
assert {auto_path1, auto_path2} == {path1, path2}

paths = generate_paths(path, [])
assert (len(paths)) == 0

# test that we modify the path inplace when passed one position
path_before = path # reference
len_before = len(path_before.route)
paths = generate_paths(path, [Position(69, 69)])
assert (
len(paths) == 1
and paths[0] == path_before
and len(path_before.route) == len(paths[0].route)
and len_before + 1 == len(path_before.route)
)
Loading

0 comments on commit 47a59de

Please sign in to comment.