diff --git a/BaseClasses.py b/BaseClasses.py index 6932147aad89..d7d6717f1545 100644 --- a/BaseClasses.py +++ b/BaseClasses.py @@ -826,21 +826,20 @@ class EntranceType(IntEnum): class Entrance: - access_rule: Callable[[CollectionState], bool] = staticmethod(lambda state: True) hide_path: bool = False player: int name: str parent_region: Optional[Region] connected_region: Optional[Region] = None - randomization_group: Hashable + randomization_group: int randomization_type: EntranceType # LttP specific, TODO: should make a LttPEntrance addresses = None target = None def __init__(self, player: int, name: str = "", parent: Region = None, - randomization_group: Hashable = 0, randomization_type: EntranceType = EntranceType.ONE_WAY): + randomization_group: int = 0, randomization_type: EntranceType = EntranceType.ONE_WAY): self.name = name self.parent_region = parent self.player = player @@ -861,18 +860,18 @@ def connect(self, region: Region, addresses: Any = None, target: Any = None) -> self.addresses = addresses region.entrances.append(self) - def is_valid_source_transition(self, state: "ERPlacementState") -> bool: + def is_valid_source_transition(self, er_state: "ERPlacementState") -> bool: """ Determines whether this is a valid source transition, that is, whether the entrance randomizer is allowed to pair it to place any other regions. By default, this is the same as a reachability check, but can be modified by Entrance implementations to add other restrictions based on the placement state. - :param state: The current (partial) state of the ongoing entrance randomization + :param er_state: The current (partial) state of the ongoing entrance randomization """ - return self.can_reach(state.collection_state) + return self.can_reach(er_state.collection_state) - def can_connect_to(self, other: Entrance, state: "ERPlacementState") -> bool: + def can_connect_to(self, other: Entrance, er_state: "ERPlacementState") -> bool: """ Determines whether a given Entrance is a valid target transition, that is, whether the entrance randomizer is allowed to pair this Entrance to that Entrance. By default, @@ -880,11 +879,11 @@ def can_connect_to(self, other: Entrance, state: "ERPlacementState") -> bool: two ways always go to two ways) and prevents connecting an exit to itself in coupled mode. :param other: The proposed Entrance to connect to - :param state: The current (partial) state of the ongoing entrance randomization + :param er_state: The current (partial) state of the ongoing entrance randomization """ # the implementation of coupled causes issues for self-loops since the reverse entrance will be the # same as the forward entrance. In uncoupled they are ok. - return self.randomization_type == other.randomization_type and (not state.coupled or self.name != other.name) + return self.randomization_type == other.randomization_type and (not er_state.coupled or self.name != other.name) def __repr__(self): return self.__str__() diff --git a/docs/entrance randomization.md b/docs/entrance randomization.md index c1c695173f54..b906b21a7a5a 100644 --- a/docs/entrance randomization.md +++ b/docs/entrance randomization.md @@ -219,7 +219,7 @@ randomized with other two-ways. You can set whether an `Entrance` is one-way or attribute. `Entrance`s can also set the `randomization_group` attribute to allow for grouping during randomization. This can be -any hashable object you define and may be based on player options. Some possible use cases for grouping include: +any integer you define and may be based on player options. Some possible use cases for grouping include: * Directional matching - only match leftward-facing transitions to rightward-facing ones * Terrain matching - only match water transitions to water transitions and land transitions to land transitions * Dungeon shuffle - only shuffle entrances within a dungeon/area with each other @@ -227,10 +227,6 @@ any hashable object you define and may be based on player options. Some possible By default, all `Entrance`s are placed in the group 0. -> [!NOTE] -> Throughout these docs, strings will be used as groups for simplicity and readability. In practice, using an int or -> IntEnum will be slightly more performant. - ### Calling generic ER Once you have defined all your entrances and exits and connected the Menu region to your region graph, you can call @@ -300,31 +296,51 @@ should connect with each other. This is done with the `target_group_lookup` and There is also a convenience function `bake_target_group_lookup` which can help to prepare group lookups when more complex group mapping logic is needed. Some recipes for `target_group_lookup` are presented here. +For the recipes below, assume the following groups (if the syntax used here is unfamiliar to you, "bit masking" and +"bitwise operators" would be the terms to search for): +```python +class Groups(IntEnum): + # Directions + LEFT = 1 + RIGHT = 2 + TOP = 3 + BOTTOM = 4 + DOOR = 5 + # Areas + FIELD = 1 << 3 + CAVE = 2 << 3 + MOUNTAIN = 3 << 3 + # Bitmasks + DIRECTION_MASK = FIELD - 1 + AREA_MASK = ~0 << 3 +``` + Directional matching: ```python direction_matching_group_lookup = { # with preserve_group_order = False, pair a left transition to either a right transition or door randomly # with preserve_group_order = True, pair a left transition to a right transition, or else a door if no # viable right transitions remain - "Left": ["Right", "Door"], + Groups.LEFT: [Groups.RIGHT, Groups.DOOR], # ... } ``` Terrain matching or dungeon shuffle: ```python -def randomize_within_same_group(group: str) -> List[str]: +def randomize_within_same_group(group: int) -> List[int]: return [group] identity_group_lookup = bake_target_group_lookup(world, randomize_within_same_group) ``` Directional + area shuffle: ```python -def get_target_groups(group: str) -> List[str]: - # example group: "Left-City" - # example result: ["Right-City", "Door-City"] - direction, area = group.split("-") - return [f"{pair_direction}-{area}" for pair_direction in direction_matching_group_lookup[direction]] +def get_target_groups(group: int) -> List[int]: + # example group: LEFT | CAVE + # example result: [RIGHT | CAVE, DOOR | CAVE] + direction = group & Groups.DIRECTION_MASK + area = group & Groups.AREA_MASK + return [pair_direction | area for pair_direction in direction_matching_group_lookup[direction]] target_group_lookup = bake_target_group_lookup(world, get_target_groups) ``` diff --git a/entrance_rando.py b/entrance_rando.py index d5302161696d..94a37187c08b 100644 --- a/entrance_rando.py +++ b/entrance_rando.py @@ -3,8 +3,7 @@ import random import time from collections import deque -from collections.abc import Hashable -from typing import Callable, Dict, Iterable, List, Tuple, Union, Set, Optional, Any +from typing import Callable, Dict, Iterable, List, Tuple, Set, Optional from BaseClasses import CollectionState, Entrance, Region, EntranceType from Options import Accessibility @@ -17,7 +16,7 @@ class EntranceRandomizationError(RuntimeError): class EntranceLookup: class GroupLookup: - _lookup: Dict[Hashable, List[Entrance]] + _lookup: Dict[int, List[Entrance]] def __init__(self): self._lookup = {} @@ -28,7 +27,7 @@ def __len__(self): def __bool__(self): return bool(self._lookup) - def __getitem__(self, item: Hashable) -> List[Entrance]: + def __getitem__(self, item: int) -> List[Entrance]: return self._lookup.get(item, []) def __iter__(self): @@ -104,7 +103,7 @@ def remove(self, entrance: Entrance) -> None: def get_targets( self, - groups: Iterable[Hashable], + groups: Iterable[int], dead_end: bool, preserve_group_order: bool ) -> Iterable[Entrance]: @@ -214,20 +213,21 @@ def connect( return [source_exit], [target_entrance] -def bake_target_group_lookup(world: World, get_target_groups: Callable[[Any], List[Hashable]]) \ - -> Dict[Hashable, List[Hashable]]: +def bake_target_group_lookup(world: World, get_target_groups: Callable[[int], List[int]]) \ + -> Dict[int, List[int]]: """ Applies a transformation to all known entrance groups on randomizable exists to build a group lookup table. :param world: Your World instance - :param get_target_groups: Function to call that returns the groups that a specific group type is allowed to connect to + :param get_target_groups: Function to call that returns the groups that a specific group type is allowed to + connect to """ unique_groups = { entrance.randomization_group for entrance in world.multiworld.get_entrances(world.player) if entrance.parent_region and not entrance.connected_region } return { group: get_target_groups(group) for group in unique_groups } -def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[Hashable] = None) -> None: +def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[int] = None) -> None: """ Given an entrance in a "vanilla" region graph, splits that entrance to prepare it for randomization in randomize_entrances. This should be done after setting the type and group of the entrance. @@ -258,7 +258,7 @@ def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Opti def randomize_entrances( world: World, coupled: bool, - target_group_lookup: Dict[Hashable, List[Hashable]], + target_group_lookup: Dict[int, List[int]], preserve_group_order: bool = False, on_connect: Optional[Callable[[ERPlacementState, List[Entrance]], None]] = None ) -> ERPlacementState: diff --git a/test/general/test_entrance_rando.py b/test/general/test_entrance_rando.py index 03b345cd4336..9b8b8bb71556 100644 --- a/test/general/test_entrance_rando.py +++ b/test/general/test_entrance_rando.py @@ -1,4 +1,5 @@ import unittest +from enum import IntEnum from typing import List, Type from BaseClasses import Region, EntranceType, MultiWorld, Entrance @@ -9,7 +10,22 @@ from worlds.generic.Rules import set_rule -def generate_entrance_pair(region: Region, name_suffix: str, group: str): +class ERTestGroups(IntEnum): + LEFT = 1 + RIGHT = 2 + TOP = 3 + BOTTOM = 4 + + +directionally_matched_group_lookup = { + ERTestGroups.LEFT: [ERTestGroups.RIGHT], + ERTestGroups.RIGHT: [ERTestGroups.LEFT], + ERTestGroups.TOP: [ERTestGroups.BOTTOM], + ERTestGroups.BOTTOM: [ERTestGroups.TOP] +} + + +def generate_entrance_pair(region: Region, name_suffix: str, group: int): lx = region.create_exit(region.name + name_suffix) lx.randomization_group = group lx.randomization_type = EntranceType.TWO_WAY @@ -36,21 +52,13 @@ def generate_disconnected_region_grid(multiworld: MultiWorld, grid_side_length: if row == 0 and col == 0: multiworld.get_region("Menu", 1).connect(region) if col != 0: - generate_entrance_pair(region, "_left", "Left") + generate_entrance_pair(region, "_left", ERTestGroups.LEFT) if col != grid_side_length - 1: - generate_entrance_pair(region, "_right", "Right") + generate_entrance_pair(region, "_right", ERTestGroups.RIGHT) if row != 0: - generate_entrance_pair(region, "_top", "Top") + generate_entrance_pair(region, "_top", ERTestGroups.TOP) if row != grid_side_length - 1: - generate_entrance_pair(region, "_bottom", "Bottom") - - -directionally_matched_group_lookup = { - "Left": ["Right"], - "Right": ["Left"], - "Top": ["Bottom"], - "Bottom": ["Top"] -} + generate_entrance_pair(region, "_bottom", ERTestGroups.BOTTOM) class TestEntranceLookup(unittest.TestCase): @@ -65,7 +73,8 @@ def test_shuffled_targets(self): for entrance in er_targets: lookup.add(entrance) - retrieved_targets = lookup.get_targets(["Top", "Bottom"], False, False) + retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], + False, False) prev = None group_order = [prev := group.randomization_group for group in retrieved_targets if prev != group.randomization_group] # technically possible that group order may not be shuffled, by some small chance, on some seeds. but generally @@ -84,10 +93,11 @@ def test_ordered_targets(self): for entrance in er_targets: lookup.add(entrance) - retrieved_targets = lookup.get_targets(["Top", "Bottom"], False, True) + retrieved_targets = lookup.get_targets([ERTestGroups.TOP, ERTestGroups.BOTTOM], + False, True) prev = None group_order = [prev := group.randomization_group for group in retrieved_targets if prev != group.randomization_group] - self.assertEqual(["Top", "Bottom"], group_order) + self.assertEqual([ERTestGroups.TOP, ERTestGroups.BOTTOM], group_order) class TestBakeTargetGroupLookup(unittest.TestCase): @@ -96,12 +106,12 @@ def test_lookup_generation(self): generate_disconnected_region_grid(multiworld, 5) world = multiworld.worlds[1] expected = { - "Left": ["tfeL"], - "Right": ["thgiR"], - "Top": ["poT"], - "Bottom": ["mottoB"] + ERTestGroups.LEFT: [-ERTestGroups.LEFT], + ERTestGroups.RIGHT: [-ERTestGroups.RIGHT], + ERTestGroups.TOP: [-ERTestGroups.TOP], + ERTestGroups.BOTTOM: [-ERTestGroups.BOTTOM] } - actual = bake_target_group_lookup(world, lambda s: [s[::-1]]) + actual = bake_target_group_lookup(world, lambda g: [-g]) self.assertEqual(expected, actual) @@ -112,7 +122,7 @@ def test_disconnect_default_2way(self): r2 = Region("r2", 1, multiworld) e = r1.create_exit("e") e.randomization_type = EntranceType.TWO_WAY - e.randomization_group = "Group1" + e.randomization_group = 1 e.connect(r2) disconnect_entrance_for_randomization(e) @@ -127,7 +137,7 @@ def test_disconnect_default_2way(self): self.assertIsNone(r1.entrances[0].parent_region) self.assertEqual("e", r1.entrances[0].name) self.assertEqual(EntranceType.TWO_WAY, r1.entrances[0].randomization_type) - self.assertEqual("Group1", r1.entrances[0].randomization_group) + self.assertEqual(1, r1.entrances[0].randomization_group) def test_disconnect_default_1way(self): multiworld = generate_test_multiworld() @@ -135,7 +145,7 @@ def test_disconnect_default_1way(self): r2 = Region("r2", 1, multiworld) e = r1.create_exit("e") e.randomization_type = EntranceType.ONE_WAY - e.randomization_group = "Group1" + e.randomization_group = 1 e.connect(r2) disconnect_entrance_for_randomization(e) @@ -150,7 +160,7 @@ def test_disconnect_default_1way(self): self.assertIsNone(r2.entrances[0].parent_region) self.assertEqual("r2", r2.entrances[0].name) self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type) - self.assertEqual("Group1", r2.entrances[0].randomization_group) + self.assertEqual(1, r2.entrances[0].randomization_group) def test_disconnect_uses_alternate_group(self): multiworld = generate_test_multiworld() @@ -158,10 +168,10 @@ def test_disconnect_uses_alternate_group(self): r2 = Region("r2", 1, multiworld) e = r1.create_exit("e") e.randomization_type = EntranceType.ONE_WAY - e.randomization_group = "Group1" + e.randomization_group = 1 e.connect(r2) - disconnect_entrance_for_randomization(e, "Group2") + disconnect_entrance_for_randomization(e, 2) self.assertIsNone(e.connected_region) self.assertEqual([], r1.entrances) @@ -173,7 +183,7 @@ def test_disconnect_uses_alternate_group(self): self.assertIsNone(r2.entrances[0].parent_region) self.assertEqual("r2", r2.entrances[0].name) self.assertEqual(EntranceType.ONE_WAY, r2.entrances[0].randomization_type) - self.assertEqual("Group2", r2.entrances[0].randomization_group) + self.assertEqual(2, r2.entrances[0].randomization_group) class TestRandomizeEntrances(unittest.TestCase): @@ -257,10 +267,10 @@ def test_oneway_twoway_pairing(self): for index, region in enumerate(["region4", "region20", "region24"]): x = multiworld.get_region(region, 1).create_exit(f"{region}_bottom_1way") x.randomization_type = EntranceType.ONE_WAY - x.randomization_group = "Bottom" + x.randomization_group = ERTestGroups.BOTTOM e = region26.create_er_target(f"region26_top_1way{index}") e.randomization_type = EntranceType.ONE_WAY - e.randomization_group = "Top" + e.randomization_group = ERTestGroups.TOP result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup) for exit_name, entrance_name in result.pairings: @@ -333,7 +343,7 @@ def test_fails_when_some_unreachable_exit(self): def test_fails_when_some_unconnectable_exit(self): """tests that entrance randomization fails if an exit can't be made into a valid placement (non-minimal)""" class CustomEntrance(Entrance): - def can_connect_to(self, other: Entrance, state: "ERPlacementState") -> bool: + def can_connect_to(self, other: Entrance, er_state: "ERPlacementState") -> bool: if other.name == "region1_right": return False