From bfc018a0d033ea4ceea9d0741cf94e9f1a0e19a7 Mon Sep 17 00:00:00 2001 From: alexo Date: Sun, 24 Dec 2023 09:37:00 +1100 Subject: [PATCH] day23: multithreaded dfs --- day23/lib/classes2.py | 81 +++++++++++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/day23/lib/classes2.py b/day23/lib/classes2.py index 967a51a..e0b4378 100644 --- a/day23/lib/classes2.py +++ b/day23/lib/classes2.py @@ -1,6 +1,11 @@ """part 2 solution""" +import math +import os +import time +from concurrent.futures import ProcessPoolExecutor as Pool from dataclasses import dataclass, field from queue import Queue +from typing import Any import colorama @@ -141,32 +146,70 @@ def solve(self) -> int: nodes: list[Node] = self.build_nodes() print("\n".join(str(node) for node in nodes)) + start = time.time() + cpu_count = os.cpu_count() or 2 + levels = int(math.log(cpu_count, 2)) + result = solve3(nodes, 0, len(nodes) - 1, 0, set(), levels) + print(time.time() - start) + return result + - return self.solve2(nodes, 0, len(nodes) - 1, 0, set()) +def solve3( + nodes: list[Node], + current: int, + destination: int, + distance: int, + seen: set[int], + forks_remaining: int, +) -> int: + if current == destination: + return distance - def solve2( - self, - nodes: list[Node], - current: int, - destination: int, - distance: int, - seen: set[int], - ) -> int: - if current == destination: - return distance + best = 0 + seen.add(current) - best = 0 - seen.add(current) + # Check if the current process is the main process + if forks_remaining == 0 or len(nodes[current].edges) == 1: for edge in nodes[current].edges: neighbor, weight = edge.node2, edge.length if neighbor in seen: continue - best = max( - best, self.solve2(nodes, neighbor, destination, distance + weight, seen) + result = solve3( + nodes, + neighbor, + destination, + distance + weight, + seen, + forks_remaining, ) - - seen.remove(current) - - return best + best = max(best, result) + else: # Use multiprocessing.Pool for parallel execution + with Pool(len(nodes[current].edges)) as pool: + tasks = [] + for edge in nodes[current].edges: + neighbor, weight = edge.node2, edge.length + if neighbor in seen: + continue + tasks.append( + [ + nodes, + neighbor, + destination, + distance + weight, + seen.copy(), + forks_remaining - 1, + ] + ) + for result in pool.map(solve3_helper, tasks): + best = max(best, result) + + seen.remove(current) + + return best + + +# ThreadPoolExecutor doesnt have starmap so we use a helper +def solve3_helper(args: list[Any]) -> int: + return solve3(*args)