From 1638e8335e295a8b568f5ec5f9310f02d21d291c Mon Sep 17 00:00:00 2001 From: alexo Date: Sun, 24 Dec 2023 03:42:43 +1100 Subject: [PATCH] day23: use dfs to greatly speed up --- day23/day23.py | 18 ++---- day23/lib/classes.py | 20 ++---- day23/lib/classes2.py | 133 +++++++++++--------------------------- day23/tests/test_day23.py | 7 +- 4 files changed, 54 insertions(+), 124 deletions(-) diff --git a/day23/day23.py b/day23/day23.py index b01e120..adf742d 100644 --- a/day23/day23.py +++ b/day23/day23.py @@ -1,4 +1,4 @@ -from day23.lib.classes import BasePath, Maze, Solver, Solver1 +from day23.lib.classes import Maze, Path, Solver1 from day23.lib.classes2 import Solver2 from day23.lib.parsers import get_maze @@ -8,25 +8,21 @@ def part1(maze: Maze) -> int: solver = Solver1(maze, True) - return run_solver(solver) + paths: list[Path] = solver.solve() + path_lengths = [len(path) for path in paths] + path_lengths.sort(reverse=True) + return path_lengths[0] def part2(maze: Maze) -> int: solver = Solver2(maze) - return run_solver(solver) - - -def run_solver(solver: Solver) -> int: - paths: list[BasePath] = solver.solve() - path_lengths = [len(path) for path in paths] - path_lengths.sort(reverse=True) - return path_lengths[0] + return solver.solve() def main() -> None: maze: Maze = get_maze(INPUT) - # print(part1(maze)) + print(part1(maze)) print(part2(maze)) diff --git a/day23/lib/classes.py b/day23/lib/classes.py index cdd3ca0..a293649 100644 --- a/day23/lib/classes.py +++ b/day23/lib/classes.py @@ -35,12 +35,7 @@ def expand(self) -> list["Position"]: ] -class BasePath: - def __len__(self) -> int: - raise NotImplementedError() - - -class Path(BasePath): +class Path: route: list[Position] nodes: set[Position] @@ -123,12 +118,7 @@ def is_oob(self, position: Position) -> bool: ) -class Solver: - def solve(self) -> list[BasePath]: - return [] - - -class Solver1(Solver): +class Solver1: maze: Maze handle_hills: bool @@ -136,13 +126,13 @@ def __init__(self, maze: Maze, handle_hills: bool = True) -> None: self.maze = maze self.handle_hills = handle_hills - def solve(self) -> list[BasePath]: + def solve(self) -> list[Path]: paths: Queue[Path] = Queue() first_path = Path() first_path.add(Position(0, 1)) paths.put(first_path) # bfs all paths simultaneously - results: list[BasePath] = [] + results: list[Path] = [] count = 1 while not paths.empty(): path = paths.get() @@ -154,8 +144,6 @@ def solve(self) -> list[BasePath]: for expansion in expansions: paths.put(expansion) count += 1 - if count % 1000 == 0: - print(count) return results diff --git a/day23/lib/classes2.py b/day23/lib/classes2.py index 6762217..cc7a32e 100644 --- a/day23/lib/classes2.py +++ b/day23/lib/classes2.py @@ -1,12 +1,10 @@ """part 2 solution""" from dataclasses import dataclass, field -from multiprocessing import Pool from queue import Queue import colorama -import tqdm -from day23.lib.classes import BasePath, Maze, Path, Position, Solver +from day23.lib.classes import Maze, Path, Position colorama.init(convert=True) @@ -26,6 +24,10 @@ class Edge: node1: int node2: int path: Path = field(repr=False) + length: int = 0 + + def __post_init__(self) -> None: + self.length = len(self.path) def flip(self) -> "Edge": return Edge(self.node2, self.node1, self.path.flip()) @@ -38,7 +40,7 @@ def __len__(self) -> int: # Same thing as path but with nodes kappa -class NodePath(BasePath): +class NodePath: path: list[int] node_ids: set[int] path_length: int @@ -73,68 +75,7 @@ def __len__(self) -> int: return self.path_length -def expand_node_path(node_path: NodePath, nodes: list[Node]) -> list[NodePath]: - """Expands a node path, giving back a list of of NodePaths""" - last_node: Node = nodes[node_path.last()] - result = [] - for edge in last_node.edges: - target_node_id: int = edge.node2 - if node_path.can_add(target_node_id): - to_add = node_path.copy() - to_add.add(target_node_id, len(edge.path)) - result.append(to_add) - return result - - -def worker_solve( - nodes: list[Node], - paths_to_process: list[NodePath], - break_early: bool, - thread_id: int, -) -> tuple[list[BasePath], list[NodePath]]: - results: list[BasePath] = [] - unfinished_paths: Queue[NodePath] = Queue() - for item in paths_to_process: - unfinished_paths.put(item) - - pbar = tqdm.tqdm( - desc=f"Thread{thread_id}", total=len(paths_to_process), position=thread_id - ) - if break_early: - pbar.total = 10000 - pbar.set_description("Initial run") - - while not unfinished_paths.empty(): - path = unfinished_paths.get() - node_id: int = path.last() - if node_id == nodes[-1].name: # end node - results.append(path) - - if break_early: - pbar.update() - if pbar.n % 10000 == 0: - break - - continue - - expansions = expand_node_path(path, nodes) - - if not break_early: - pbar.total += len(expansions) - pbar.update() - - for p in expansions: - unfinished_paths.put(p) - pbar.close() - return results, list(unfinished_paths.queue) - - -def split_list(items: list[NodePath], num_chunks: int) -> list[list[NodePath]]: - chunk_size = (len(items) // num_chunks) + 1 - return [items[i * chunk_size : (i + 1) * chunk_size] for i in range(num_chunks)] - - -class Solver2(Solver): +class Solver2: maze: Maze def __init__(self, maze: Maze) -> None: @@ -232,36 +173,36 @@ def build_nodes(self) -> list[Node]: return list(nodes.values()) - def solve(self) -> list[BasePath]: + def solve(self) -> int: nodes: list[Node] = self.build_nodes() - # print our nodes out: + print("\n".join(str(node) for node in nodes)) - first_path = NodePath() - first_path.add(0) - - unfinished_paths: list[NodePath] = [] - unfinished_paths.append(first_path) - - results, unfinished_paths = worker_solve(nodes, unfinished_paths, True, 0) - - # time for multithreading! - num_workers = 8 - unfinished_chunks: list[list[NodePath]] = split_list( - unfinished_paths, num_workers - ) - - with Pool(num_workers) as pool: - worker_args = [ - (nodes, unfinished_chunks[i], False, i) for i in range(num_workers) - ] - result_objects = pool.starmap_async(worker_solve, worker_args) - pool_results = result_objects.get() - for pool_result in pool_results: - paths = pool_result[0] - results.extend(paths) - print("\n" * num_workers * 2) # fix bug in progress bars - # split unfinished_paths: - results.sort(key=lambda x: len(x), reverse=True) - print("total results:", len(results)) - return results + return self.solve2(nodes, 0, len(nodes) - 1, 0, set()) + + def solve2( + self, + nodes: list[Node], + current: int, + destination: int, + distance: int, + seen: set[int], + ) -> int: + if current == destination: + return distance + + best = 0 + seen.add(current) + + for edge in nodes[current].edges: + neighbor, weight = edge.node2, edge.length + if neighbor in seen: + continue + + best = max( + best, self.solve2(nodes, neighbor, destination, distance + weight, seen) + ) + + seen.remove(current) + + return best diff --git a/day23/tests/test_day23.py b/day23/tests/test_day23.py index 6208a82..2d02aaa 100644 --- a/day23/tests/test_day23.py +++ b/day23/tests/test_day23.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from day23.day23 import INPUT_SMALL, part1 +from day23.day23 import INPUT_SMALL, part1, part2 from day23.lib.classes import Solver1 from day23.lib.parsers import get_maze @@ -23,3 +23,8 @@ def test_solver() -> None: def test_part1() -> None: maze: Maze = get_maze(INPUT_SMALL) assert part1(maze) == 94 + + +def test_part2() -> None: + maze: Maze = get_maze(INPUT_SMALL) + assert part2(maze) == 154