diff --git a/day20/day20.py b/day20/day20.py index b4e85bb..191e223 100644 --- a/day20/day20.py +++ b/day20/day20.py @@ -51,16 +51,6 @@ def simulate(modules: dict[str, BaseModule]) -> tuple[int, int]: return low, high -def get_final_gates(module_map: dict[str, BaseModule]) -> list[ConjunctionModule]: - """Should return vd, tp, pt, bk""" - result: list[ConjunctionModule] = [] - for module in module_map.values(): - if isinstance(module, ConjunctionModule) and len(module.inputs) >= 8: - result.append(module) - print([module.name for module in result]) - return result - - def get_loop_paths( start_switch: str, module_map: dict[str, BaseModule] ) -> list[BaseModule]: @@ -75,14 +65,15 @@ def get_loop_paths( current_module = module_map[current_module.outputs[0]] else: # return flipflop in outputs - flipflops = ( - output - for output in current_module.outputs - if isinstance(module_map[output], FlipFlopModule) - ) - current_module = module_map[next(flipflops)] + outputs: list[BaseModule] = [ + module_map[output] for output in current_module.outputs + ] + filtered = [node for node in outputs if isinstance(node, FlipFlopModule)] + assert len(filtered) == 1 + current_module = filtered[0] path.append(current_module) # should be a ConjunctionModule + assert isinstance(current_module, ConjunctionModule) return path @@ -108,13 +99,17 @@ def get_module_groups(module_map: dict[str, BaseModule]) -> ModuleGroups: get_loop_paths(node_name, module_map) for node_name in broadcaster.outputs ] loop_tails: list[ConjunctionModule] = [] + + # for each loop, the conjunction module goes to one other conjunction. + # grab this set of conjunctions, they are known as loop_tails for loop_path in loop_paths: last_node = loop_path[-1] - for node_name in last_node.outputs: - node = module_map[node_name] - if isinstance(node, ConjunctionModule): - loop_tails.append(node) - break + + nodes = [module_map[node_name] for node_name in last_node.outputs] + filtered = [node for node in nodes if isinstance(node, ConjunctionModule)] + assert len(filtered) == 1 + loop_tails.extend(filtered) + last_join_name = loop_tails[0].outputs[0] last_conjunction = get_typed_module(module_map, last_join_name, ConjunctionModule) sink = get_typed_module(module_map, "rx", SinkModule) diff --git a/day20/tests/test_day20.py b/day20/tests/test_day20.py index ccbdbd2..b807e18 100644 --- a/day20/tests/test_day20.py +++ b/day20/tests/test_day20.py @@ -1,4 +1,4 @@ -from day20.day20 import FILE_A, FILE_B, part1 +from day20.day20 import FILE_A, FILE_B, FILE_PROD, part1, part2 from day20.lib.parsers import get_modules @@ -8,3 +8,8 @@ def test_day20() -> None: modules = get_modules(FILE_B) assert part1(modules) == 11687500 + + +def test_part2() -> None: + modules = get_modules(FILE_PROD) + assert part2(modules)[0] == 252667369442479