diff --git a/day23/lib/classes2.py b/day23/lib/classes2.py index 7adf96d..6762217 100644 --- a/day23/lib/classes2.py +++ b/day23/lib/classes2.py @@ -1,12 +1,15 @@ """part 2 solution""" -import time from dataclasses import dataclass, field +from multiprocessing import Pool from queue import Queue import colorama +import tqdm from day23.lib.classes import BasePath, Maze, Path, Position, Solver +colorama.init(convert=True) + @dataclass(eq=True) class Node: @@ -70,6 +73,67 @@ def __len__(self) -> int: return self.path_length +def expand_node_path(node_path: NodePath, nodes: list[Node]) -> list[NodePath]: + """Expands a node path, giving back a list of of NodePaths""" + last_node: Node = nodes[node_path.last()] + result = [] + for edge in last_node.edges: + target_node_id: int = edge.node2 + if node_path.can_add(target_node_id): + to_add = node_path.copy() + to_add.add(target_node_id, len(edge.path)) + result.append(to_add) + return result + + +def worker_solve( + nodes: list[Node], + paths_to_process: list[NodePath], + break_early: bool, + thread_id: int, +) -> tuple[list[BasePath], list[NodePath]]: + results: list[BasePath] = [] + unfinished_paths: Queue[NodePath] = Queue() + for item in paths_to_process: + unfinished_paths.put(item) + + pbar = tqdm.tqdm( + desc=f"Thread{thread_id}", total=len(paths_to_process), position=thread_id + ) + if break_early: + pbar.total = 10000 + pbar.set_description("Initial run") + + while not unfinished_paths.empty(): + path = unfinished_paths.get() + node_id: int = path.last() + if node_id == nodes[-1].name: # end node + results.append(path) + + if break_early: + pbar.update() + if pbar.n % 10000 == 0: + break + + continue + + expansions = expand_node_path(path, nodes) + + if not break_early: + pbar.total += len(expansions) + pbar.update() + + for p in expansions: + unfinished_paths.put(p) + pbar.close() + return results, list(unfinished_paths.queue) + + +def split_list(items: list[NodePath], num_chunks: int) -> list[list[NodePath]]: + chunk_size = (len(items) // num_chunks) + 1 + return [items[i * chunk_size : (i + 1) * chunk_size] for i in range(num_chunks)] + + class Solver2(Solver): maze: Maze @@ -160,19 +224,6 @@ def expand_path(self, path: Path) -> list[Path]: return result - def expand_node_path( - self, node_path: NodePath, nodes: list[Node] - ) -> list[NodePath]: - last_node: Node = nodes[node_path.last()] - result = [] - for edge in last_node.edges: - target_node_id: int = edge.node2 - if node_path.can_add(target_node_id): - to_add = node_path.copy() - to_add.add(target_node_id, len(edge.path)) - result.append(to_add) - return result - def build_nodes(self) -> list[Node]: nodes: dict[Position, Node] = self.get_nodes() print(self.maze) @@ -183,29 +234,34 @@ def build_nodes(self) -> list[Node]: def solve(self) -> list[BasePath]: nodes: list[Node] = self.build_nodes() + # print our nodes out: + print("\n".join(str(node) for node in nodes)) + first_path = NodePath() first_path.add(0) - paths: Queue[NodePath] = Queue() - paths.put(first_path) - print("\n".join(str(node) for node in nodes)) - last = time.time() - results: list[BasePath] = [] - count = 0 - while not paths.empty(): - path: NodePath = paths.get() - node_id: int = path.last() - if node_id == nodes[-1].name: # end node - # reached an edge - count += 1 - results.append(path) - if count % 10000 == 0: - print(paths.qsize(), path, time.time() - last) - last = time.time() - continue - expansions = self.expand_node_path(path, nodes) - for path in expansions: - paths.put(path) + unfinished_paths: list[NodePath] = [] + unfinished_paths.append(first_path) + + results, unfinished_paths = worker_solve(nodes, unfinished_paths, True, 0) + + # time for multithreading! + num_workers = 8 + unfinished_chunks: list[list[NodePath]] = split_list( + unfinished_paths, num_workers + ) + + with Pool(num_workers) as pool: + worker_args = [ + (nodes, unfinished_chunks[i], False, i) for i in range(num_workers) + ] + result_objects = pool.starmap_async(worker_solve, worker_args) + pool_results = result_objects.get() + for pool_result in pool_results: + paths = pool_result[0] + results.extend(paths) + print("\n" * num_workers * 2) # fix bug in progress bars + # split unfinished_paths: results.sort(key=lambda x: len(x), reverse=True) - + print("total results:", len(results)) return results