diff --git a/BaseClasses.py b/BaseClasses.py index 7df178b79105..c2da3ab74fe6 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -9,7 +9,7 @@ import typing # this can go away when Python 3.8 support is dropped from argparse import Namespace from collections import Counter, deque -from collections.abc import Collection, MutableSequence +from collections.abc import Collection, MutableSequence, Hashable from enum import IntEnum, IntFlag from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union, \ Type, ClassVar @@ -777,14 +777,14 @@ class Entrance: name: str parent_region: Optional[Region] connected_region: Optional[Region] = None - randomization_group: str + randomization_group: Hashable randomization_type: EntranceType # LttP specific, TODO: should make a LttPEntrance addresses = None target = None def __init__(self, player: int, name: str = "", parent: Region = None, - randomization_group: str = "Default", randomization_type: EntranceType = EntranceType.ONE_WAY): + randomization_group: Hashable = 0, randomization_type: EntranceType = EntranceType.ONE_WAY): self.name = name self.parent_region = parent self.player = player diff --git a/docs/entrance randomization.md b/docs/entrance randomization.md index 03a3205b9147..c1c695173f54 100644 --- a/docs/entrance randomization.md +++ b/docs/entrance randomization.md @@ -219,13 +219,17 @@ randomized with other two-ways. You can set whether an `Entrance` is one-way or attribute. `Entrance`s can also set the `randomization_group` attribute to allow for grouping during randomization. This can be -any arbitrary string you define and may be based on player options. Some possible use cases for grouping include: +any hashable object you define and may be based on player options. Some possible use cases for grouping include: * Directional matching - only match leftward-facing transitions to rightward-facing ones * Terrain matching - only match water transitions to water transitions and land transitions to land transitions * Dungeon shuffle - only shuffle entrances within a dungeon/area with each other * Combinations of the above -By default, all `Entrance`s are placed in the `"Default"` group. +By default, all `Entrance`s are placed in the group 0. + +> [!NOTE] +> Throughout these docs, strings will be used as groups for simplicity and readability. In practice, using an int or +> IntEnum will be slightly more performant. ### Calling generic ER @@ -292,24 +296,26 @@ graph LR #### Implementing grouping When you created your entrances, you defined the group each entrance belongs to. Now you will have to define how groups -should connect with each other. This is done with the `get_target_groups` and `preserve_group_order` parameters. Some -recipes for `get_target_groups` are presented here. +should connect with each other. This is done with the `target_group_lookup` and `preserve_group_order` parameters. +There is also a convenience function `bake_target_group_lookup` which can help to prepare group lookups when more +complex group mapping logic is needed. Some recipes for `target_group_lookup` are presented here. Directional matching: ```python -def match_direction(group: str) -> List[str]: - if group == "Left": - # with preserve_group_order = False, pair a left transition to either a right transition or door randomly - # with preserve_group_order = True, pair a left transition to a right transition, or else a door if no - # viable right transitions remain - return ["Right", "Door"] +direction_matching_group_lookup = { + # with preserve_group_order = False, pair a left transition to either a right transition or door randomly + # with preserve_group_order = True, pair a left transition to a right transition, or else a door if no + # viable right transitions remain + "Left": ["Right", "Door"], # ... +} ``` Terrain matching or dungeon shuffle: ```python def randomize_within_same_group(group: str) -> List[str]: return [group] +identity_group_lookup = bake_target_group_lookup(world, randomize_within_same_group) ``` Directional + area shuffle: @@ -318,7 +324,8 @@ def get_target_groups(group: str) -> List[str]: # example group: "Left-City" # example result: ["Right-City", "Door-City"] direction, area = group.split("-") - return [f"{pair_direction}-{area}" for pair_direction in match_direction(direction)] + return [f"{pair_direction}-{area}" for pair_direction in direction_matching_group_lookup[direction]] +target_group_lookup = bake_target_group_lookup(world, get_target_groups) ``` #### When to call `randomize_entrances` @@ -367,7 +374,7 @@ from Menu, similar to fill. ER then proceeds in stages to complete the randomiza The process for each connection will do the following: 1. Select a randomizable exit of a reachable region which is a valid source transition. -2. Get its group and call `get_target_groups` to determine which groups are valid targets. +2. Get its group and check `target_group_lookup` to determine which groups are valid targets. 3. Look up ER targets from those groups and find one which is valid according to `can_connect_to` 4. Connect the source exit to the target's target_region and delete the target. 5. If it's coupled mode, find the reverse exit and target by name and connect them as well. diff --git a/entrance_rando.py b/entrance_rando.py index 283dab5f5230..d5302161696d 100644 --- a/entrance_rando.py +++ b/entrance_rando.py @@ -3,7 +3,8 @@ import random import time from collections import deque -from typing import Callable, Dict, Iterable, List, Tuple, Union, Set, Optional +from collections.abc import Hashable +from typing import Callable, Dict, Iterable, List, Tuple, Union, Set, Optional, Any from BaseClasses import CollectionState, Entrance, Region, EntranceType from Options import Accessibility @@ -16,7 +17,7 @@ class EntranceRandomizationError(RuntimeError): class EntranceLookup: class GroupLookup: - _lookup: Dict[str, List[Entrance]] + _lookup: Dict[Hashable, List[Entrance]] def __init__(self): self._lookup = {} @@ -27,7 +28,7 @@ def __len__(self): def __bool__(self): return bool(self._lookup) - def __getitem__(self, item: str) -> List[Entrance]: + def __getitem__(self, item: Hashable) -> List[Entrance]: return self._lookup.get(item, []) def __iter__(self): @@ -103,7 +104,7 @@ def remove(self, entrance: Entrance) -> None: def get_targets( self, - groups: Iterable[str], + groups: Iterable[Hashable], dead_end: bool, preserve_group_order: bool ) -> Iterable[Entrance]: @@ -213,7 +214,20 @@ def connect( return [source_exit], [target_entrance] -def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[str] = None) -> None: +def bake_target_group_lookup(world: World, get_target_groups: Callable[[Any], List[Hashable]]) \ + -> Dict[Hashable, List[Hashable]]: + """ + Applies a transformation to all known entrance groups on randomizable exists to build a group lookup table. + + :param world: Your World instance + :param get_target_groups: Function to call that returns the groups that a specific group type is allowed to connect to + """ + unique_groups = { entrance.randomization_group for entrance in world.multiworld.get_entrances(world.player) + if entrance.parent_region and not entrance.connected_region } + return { group: get_target_groups(group) for group in unique_groups } + + +def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[Hashable] = None) -> None: """ Given an entrance in a "vanilla" region graph, splits that entrance to prepare it for randomization in randomize_entrances. This should be done after setting the type and group of the entrance. @@ -244,7 +258,7 @@ def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Opti def randomize_entrances( world: World, coupled: bool, - get_target_groups: Callable[[str], List[str]], + target_group_lookup: Dict[Hashable, List[Hashable]], preserve_group_order: bool = False, on_connect: Optional[Callable[[ERPlacementState, List[Entrance]], None]] = None ) -> ERPlacementState: @@ -253,7 +267,9 @@ def randomize_entrances( :param world: Your World instance :param coupled: Whether connected entrances should be coupled to go in both directions - :param get_target_groups: Function to call that returns the groups that a specific group type is allowed to connect to + :param target_group_lookup: Map from each group to a list of the groups that it can be connect to. Every group + used on an exit must be provided and must map to at least one other group. The default + group is 0. :param preserve_group_order: Whether the order of groupings should be preserved for the returned target_groups :param on_connect: A callback function which allows specifying side effects after a placement is completed successfully and the underlying collection state has been updated. @@ -278,11 +294,7 @@ def find_pairing(dead_end: bool, require_new_regions: bool) -> bool: nonlocal perform_validity_check placeable_exits = er_state.find_placeable_exits(perform_validity_check) for source_exit in placeable_exits: - target_groups = get_target_groups(source_exit.randomization_group) - # anything can connect to the default group - if people don't like it the fix is to - # assign a non-default group - if "Default" not in target_groups: - target_groups.append("Default") + target_groups = target_group_lookup[source_exit.randomization_group] for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order): # requiring a new region is a proxy for enforcing new entrances are added, thus growing the search # space. this is not quite a full fidelity conversion, but doesn't seem to cause issues enough diff --git a/test/general/test_entrance_rando.py b/test/general/test_entrance_rando.py index 60ab1bce5d19..03b345cd4336 100644 --- a/test/general/test_entrance_rando.py +++ b/test/general/test_entrance_rando.py @@ -3,7 +3,7 @@ from BaseClasses import Region, EntranceType, MultiWorld, Entrance from entrance_rando import disconnect_entrance_for_randomization, randomize_entrances, EntranceRandomizationError, \ - ERPlacementState, EntranceLookup + ERPlacementState, EntranceLookup, bake_target_group_lookup from Options import Accessibility from test.general import generate_test_multiworld, generate_locations, generate_items from worlds.generic.Rules import set_rule @@ -45,17 +45,12 @@ def generate_disconnected_region_grid(multiworld: MultiWorld, grid_side_length: generate_entrance_pair(region, "_bottom", "Bottom") -def directionally_matched_group_selection(group: str) -> List[str]: - if group == "Left": - return ["Right"] - elif group == "Right": - return ["Left"] - elif group == "Top": - return ["Bottom"] - elif group == "Bottom": - return ["Top"] - else: - return [] +directionally_matched_group_lookup = { + "Left": ["Right"], + "Right": ["Left"], + "Top": ["Bottom"], + "Bottom": ["Top"] +} class TestEntranceLookup(unittest.TestCase): @@ -95,6 +90,21 @@ def test_ordered_targets(self): self.assertEqual(["Top", "Bottom"], group_order) +class TestBakeTargetGroupLookup(unittest.TestCase): + def test_lookup_generation(self): + multiworld = generate_test_multiworld() + generate_disconnected_region_grid(multiworld, 5) + world = multiworld.worlds[1] + expected = { + "Left": ["tfeL"], + "Right": ["thgiR"], + "Top": ["poT"], + "Bottom": ["mottoB"] + } + actual = bake_target_group_lookup(world, lambda s: [s[::-1]]) + self.assertEqual(expected, actual) + + class TestDisconnectForRandomization(unittest.TestCase): def test_disconnect_default_2way(self): multiworld = generate_test_multiworld() @@ -176,8 +186,8 @@ def test_determinism(self): multiworld2.worlds[1].random = multiworld2.per_slot_randoms[1] generate_disconnected_region_grid(multiworld2, 5) - result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_selection) - result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_selection) + result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_lookup) + result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_lookup) self.assertEqual(result1.pairings, result2.pairings) for e1, e2 in zip(result1.placements, result2.placements): self.assertEqual(e1.name, e2.name) @@ -190,7 +200,7 @@ def test_all_entrances_placed(self): multiworld.worlds[1].random = multiworld.per_slot_randoms[1] generate_disconnected_region_grid(multiworld, 5) - result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection) + result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup) self.assertEqual([], [entrance for region in multiworld.get_regions() for entrance in region.entrances if not entrance.parent_region]) @@ -215,7 +225,7 @@ def verify_coupled(state: ERPlacementState, placed_entrances: List[Entrance]): self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region) self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region) - result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_selection, + result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup, on_connect=verify_coupled) # if we didn't visit every placement the verification on_connect doesn't really mean much self.assertEqual(len(result.placements), seen_placement_count) @@ -232,7 +242,7 @@ def verify_uncoupled(state: ERPlacementState, placed_entrances: List[Entrance]): seen_placement_count += len(placed_entrances) self.assertEqual(1, len(placed_entrances)) - result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection, + result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup, on_connect=verify_uncoupled) # if we didn't visit every placement the verification on_connect doesn't really mean much self.assertEqual(len(result.placements), seen_placement_count) @@ -252,7 +262,7 @@ def test_oneway_twoway_pairing(self): e.randomization_type = EntranceType.ONE_WAY e.randomization_group = "Top" - result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection) + result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup) for exit_name, entrance_name in result.pairings: # we have labeled our entrances in such a way that all the 1 way entrances have 1way in the name, # so test for that since the ER target will have been discarded @@ -265,7 +275,7 @@ def test_group_constraints_satisfied(self): multiworld.worlds[1].random = multiworld.per_slot_randoms[1] generate_disconnected_region_grid(multiworld, 5) - result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection) + result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup) for exit_name, entrance_name in result.pairings: # we have labeled our entrances in such a way that all the entrances contain their group in the name # so test for that since the ER target will have been discarded @@ -292,7 +302,7 @@ def test_minimal_entrance_rando(self): e = multiworld.get_entrance("region1_right", 1) set_rule(e, lambda state: False) - randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection) + randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup) self.assertEqual([], [entrance for region in multiworld.get_regions() for entrance in region.entrances if not entrance.parent_region]) @@ -307,7 +317,7 @@ def test_fails_when_mismatched_entrance_and_exit_count(self): multiworld.get_region("region1", 1).create_exit("extra") self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False, - directionally_matched_group_selection) + directionally_matched_group_lookup) def test_fails_when_some_unreachable_exit(self): """tests that entrance randomization fails if an exit is never reachable (non-minimal accessibility)""" @@ -318,7 +328,7 @@ def test_fails_when_some_unreachable_exit(self): set_rule(e, lambda state: False) self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False, - directionally_matched_group_selection) + directionally_matched_group_lookup) def test_fails_when_some_unconnectable_exit(self): """tests that entrance randomization fails if an exit can't be made into a valid placement (non-minimal)""" @@ -335,7 +345,7 @@ class CustomRegion(Region): generate_disconnected_region_grid(multiworld, 5, region_type=CustomRegion) self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False, - directionally_matched_group_selection) + directionally_matched_group_lookup) def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self): """ @@ -353,4 +363,4 @@ def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self): set_rule(e, lambda state: False) self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False, - directionally_matched_group_selection) + directionally_matched_group_lookup)