Skip to content

Commit

Permalink
Update groups to allow any hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
BadMagic100 committed Jun 23, 2024
1 parent 743981d commit c89f185
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 51 deletions.
6 changes: 3 additions & 3 deletions BaseClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import typing # this can go away when Python 3.8 support is dropped
from argparse import Namespace
from collections import Counter, deque
from collections.abc import Collection, MutableSequence
from collections.abc import Collection, MutableSequence, Hashable
from enum import IntEnum, IntFlag
from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Optional, Set, Tuple, TypedDict, Union, \
Type, ClassVar
Expand Down Expand Up @@ -777,14 +777,14 @@ class Entrance:
name: str
parent_region: Optional[Region]
connected_region: Optional[Region] = None
randomization_group: str
randomization_group: Hashable
randomization_type: EntranceType
# LttP specific, TODO: should make a LttPEntrance
addresses = None
target = None

def __init__(self, player: int, name: str = "", parent: Region = None,
randomization_group: str = "Default", randomization_type: EntranceType = EntranceType.ONE_WAY):
randomization_group: Hashable = 0, randomization_type: EntranceType = EntranceType.ONE_WAY):
self.name = name
self.parent_region = parent
self.player = player
Expand Down
31 changes: 19 additions & 12 deletions docs/entrance randomization.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,17 @@ randomized with other two-ways. You can set whether an `Entrance` is one-way or
attribute.

`Entrance`s can also set the `randomization_group` attribute to allow for grouping during randomization. This can be
any arbitrary string you define and may be based on player options. Some possible use cases for grouping include:
any hashable object you define and may be based on player options. Some possible use cases for grouping include:
* Directional matching - only match leftward-facing transitions to rightward-facing ones
* Terrain matching - only match water transitions to water transitions and land transitions to land transitions
* Dungeon shuffle - only shuffle entrances within a dungeon/area with each other
* Combinations of the above

By default, all `Entrance`s are placed in the `"Default"` group.
By default, all `Entrance`s are placed in the group 0.

> [!NOTE]
> Throughout these docs, strings will be used as groups for simplicity and readability. In practice, using an int or
> IntEnum will be slightly more performant.
### Calling generic ER

Expand Down Expand Up @@ -292,24 +296,26 @@ graph LR
#### Implementing grouping

When you created your entrances, you defined the group each entrance belongs to. Now you will have to define how groups
should connect with each other. This is done with the `get_target_groups` and `preserve_group_order` parameters. Some
recipes for `get_target_groups` are presented here.
should connect with each other. This is done with the `target_group_lookup` and `preserve_group_order` parameters.
There is also a convenience function `bake_target_group_lookup` which can help to prepare group lookups when more
complex group mapping logic is needed. Some recipes for `target_group_lookup` are presented here.

Directional matching:
```python
def match_direction(group: str) -> List[str]:
if group == "Left":
# with preserve_group_order = False, pair a left transition to either a right transition or door randomly
# with preserve_group_order = True, pair a left transition to a right transition, or else a door if no
# viable right transitions remain
return ["Right", "Door"]
direction_matching_group_lookup = {
# with preserve_group_order = False, pair a left transition to either a right transition or door randomly
# with preserve_group_order = True, pair a left transition to a right transition, or else a door if no
# viable right transitions remain
"Left": ["Right", "Door"],
# ...
}
```

Terrain matching or dungeon shuffle:
```python
def randomize_within_same_group(group: str) -> List[str]:
return [group]
identity_group_lookup = bake_target_group_lookup(world, randomize_within_same_group)
```

Directional + area shuffle:
Expand All @@ -318,7 +324,8 @@ def get_target_groups(group: str) -> List[str]:
# example group: "Left-City"
# example result: ["Right-City", "Door-City"]
direction, area = group.split("-")
return [f"{pair_direction}-{area}" for pair_direction in match_direction(direction)]
return [f"{pair_direction}-{area}" for pair_direction in direction_matching_group_lookup[direction]]
target_group_lookup = bake_target_group_lookup(world, get_target_groups)
```

#### When to call `randomize_entrances`
Expand Down Expand Up @@ -367,7 +374,7 @@ from Menu, similar to fill. ER then proceeds in stages to complete the randomiza

The process for each connection will do the following:
1. Select a randomizable exit of a reachable region which is a valid source transition.
2. Get its group and call `get_target_groups` to determine which groups are valid targets.
2. Get its group and check `target_group_lookup` to determine which groups are valid targets.
3. Look up ER targets from those groups and find one which is valid according to `can_connect_to`
4. Connect the source exit to the target's target_region and delete the target.
5. If it's coupled mode, find the reverse exit and target by name and connect them as well.
Expand Down
36 changes: 24 additions & 12 deletions entrance_rando.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import random
import time
from collections import deque
from typing import Callable, Dict, Iterable, List, Tuple, Union, Set, Optional
from collections.abc import Hashable
from typing import Callable, Dict, Iterable, List, Tuple, Union, Set, Optional, Any

from BaseClasses import CollectionState, Entrance, Region, EntranceType
from Options import Accessibility
Expand All @@ -16,7 +17,7 @@ class EntranceRandomizationError(RuntimeError):

class EntranceLookup:
class GroupLookup:
_lookup: Dict[str, List[Entrance]]
_lookup: Dict[Hashable, List[Entrance]]

def __init__(self):
self._lookup = {}
Expand All @@ -27,7 +28,7 @@ def __len__(self):
def __bool__(self):
return bool(self._lookup)

def __getitem__(self, item: str) -> List[Entrance]:
def __getitem__(self, item: Hashable) -> List[Entrance]:
return self._lookup.get(item, [])

def __iter__(self):
Expand Down Expand Up @@ -103,7 +104,7 @@ def remove(self, entrance: Entrance) -> None:

def get_targets(
self,
groups: Iterable[str],
groups: Iterable[Hashable],
dead_end: bool,
preserve_group_order: bool
) -> Iterable[Entrance]:
Expand Down Expand Up @@ -213,7 +214,20 @@ def connect(
return [source_exit], [target_entrance]


def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[str] = None) -> None:
def bake_target_group_lookup(world: World, get_target_groups: Callable[[Any], List[Hashable]]) \
-> Dict[Hashable, List[Hashable]]:
"""
Applies a transformation to all known entrance groups on randomizable exists to build a group lookup table.
:param world: Your World instance
:param get_target_groups: Function to call that returns the groups that a specific group type is allowed to connect to
"""
unique_groups = { entrance.randomization_group for entrance in world.multiworld.get_entrances(world.player)
if entrance.parent_region and not entrance.connected_region }
return { group: get_target_groups(group) for group in unique_groups }


def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Optional[Hashable] = None) -> None:
"""
Given an entrance in a "vanilla" region graph, splits that entrance to prepare it for randomization
in randomize_entrances. This should be done after setting the type and group of the entrance.
Expand Down Expand Up @@ -244,7 +258,7 @@ def disconnect_entrance_for_randomization(entrance: Entrance, target_group: Opti
def randomize_entrances(
world: World,
coupled: bool,
get_target_groups: Callable[[str], List[str]],
target_group_lookup: Dict[Hashable, List[Hashable]],
preserve_group_order: bool = False,
on_connect: Optional[Callable[[ERPlacementState, List[Entrance]], None]] = None
) -> ERPlacementState:
Expand All @@ -253,7 +267,9 @@ def randomize_entrances(
:param world: Your World instance
:param coupled: Whether connected entrances should be coupled to go in both directions
:param get_target_groups: Function to call that returns the groups that a specific group type is allowed to connect to
:param target_group_lookup: Map from each group to a list of the groups that it can be connect to. Every group
used on an exit must be provided and must map to at least one other group. The default
group is 0.
:param preserve_group_order: Whether the order of groupings should be preserved for the returned target_groups
:param on_connect: A callback function which allows specifying side effects after a placement is completed
successfully and the underlying collection state has been updated.
Expand All @@ -278,11 +294,7 @@ def find_pairing(dead_end: bool, require_new_regions: bool) -> bool:
nonlocal perform_validity_check
placeable_exits = er_state.find_placeable_exits(perform_validity_check)
for source_exit in placeable_exits:
target_groups = get_target_groups(source_exit.randomization_group)
# anything can connect to the default group - if people don't like it the fix is to
# assign a non-default group
if "Default" not in target_groups:
target_groups.append("Default")
target_groups = target_group_lookup[source_exit.randomization_group]
for target_entrance in entrance_lookup.get_targets(target_groups, dead_end, preserve_group_order):
# requiring a new region is a proxy for enforcing new entrances are added, thus growing the search
# space. this is not quite a full fidelity conversion, but doesn't seem to cause issues enough
Expand Down
58 changes: 34 additions & 24 deletions test/general/test_entrance_rando.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from BaseClasses import Region, EntranceType, MultiWorld, Entrance
from entrance_rando import disconnect_entrance_for_randomization, randomize_entrances, EntranceRandomizationError, \
ERPlacementState, EntranceLookup
ERPlacementState, EntranceLookup, bake_target_group_lookup
from Options import Accessibility
from test.general import generate_test_multiworld, generate_locations, generate_items
from worlds.generic.Rules import set_rule
Expand Down Expand Up @@ -45,17 +45,12 @@ def generate_disconnected_region_grid(multiworld: MultiWorld, grid_side_length:
generate_entrance_pair(region, "_bottom", "Bottom")


def directionally_matched_group_selection(group: str) -> List[str]:
if group == "Left":
return ["Right"]
elif group == "Right":
return ["Left"]
elif group == "Top":
return ["Bottom"]
elif group == "Bottom":
return ["Top"]
else:
return []
directionally_matched_group_lookup = {
"Left": ["Right"],
"Right": ["Left"],
"Top": ["Bottom"],
"Bottom": ["Top"]
}


class TestEntranceLookup(unittest.TestCase):
Expand Down Expand Up @@ -95,6 +90,21 @@ def test_ordered_targets(self):
self.assertEqual(["Top", "Bottom"], group_order)


class TestBakeTargetGroupLookup(unittest.TestCase):
def test_lookup_generation(self):
multiworld = generate_test_multiworld()
generate_disconnected_region_grid(multiworld, 5)
world = multiworld.worlds[1]
expected = {
"Left": ["tfeL"],
"Right": ["thgiR"],
"Top": ["poT"],
"Bottom": ["mottoB"]
}
actual = bake_target_group_lookup(world, lambda s: [s[::-1]])
self.assertEqual(expected, actual)


class TestDisconnectForRandomization(unittest.TestCase):
def test_disconnect_default_2way(self):
multiworld = generate_test_multiworld()
Expand Down Expand Up @@ -176,8 +186,8 @@ def test_determinism(self):
multiworld2.worlds[1].random = multiworld2.per_slot_randoms[1]
generate_disconnected_region_grid(multiworld2, 5)

result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_selection)
result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_selection)
result1 = randomize_entrances(multiworld1.worlds[1], False, directionally_matched_group_lookup)
result2 = randomize_entrances(multiworld2.worlds[1], False, directionally_matched_group_lookup)
self.assertEqual(result1.pairings, result2.pairings)
for e1, e2 in zip(result1.placements, result2.placements):
self.assertEqual(e1.name, e2.name)
Expand All @@ -190,7 +200,7 @@ def test_all_entrances_placed(self):
multiworld.worlds[1].random = multiworld.per_slot_randoms[1]
generate_disconnected_region_grid(multiworld, 5)

result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection)
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

self.assertEqual([], [entrance for region in multiworld.get_regions()
for entrance in region.entrances if not entrance.parent_region])
Expand All @@ -215,7 +225,7 @@ def verify_coupled(state: ERPlacementState, placed_entrances: List[Entrance]):
self.assertEqual(placed_entrances[0].parent_region, placed_entrances[1].connected_region)
self.assertEqual(placed_entrances[1].parent_region, placed_entrances[0].connected_region)

result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_selection,
result = randomize_entrances(multiworld.worlds[1], True, directionally_matched_group_lookup,
on_connect=verify_coupled)
# if we didn't visit every placement the verification on_connect doesn't really mean much
self.assertEqual(len(result.placements), seen_placement_count)
Expand All @@ -232,7 +242,7 @@ def verify_uncoupled(state: ERPlacementState, placed_entrances: List[Entrance]):
seen_placement_count += len(placed_entrances)
self.assertEqual(1, len(placed_entrances))

result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection,
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup,
on_connect=verify_uncoupled)
# if we didn't visit every placement the verification on_connect doesn't really mean much
self.assertEqual(len(result.placements), seen_placement_count)
Expand All @@ -252,7 +262,7 @@ def test_oneway_twoway_pairing(self):
e.randomization_type = EntranceType.ONE_WAY
e.randomization_group = "Top"

result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection)
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
for exit_name, entrance_name in result.pairings:
# we have labeled our entrances in such a way that all the 1 way entrances have 1way in the name,
# so test for that since the ER target will have been discarded
Expand All @@ -265,7 +275,7 @@ def test_group_constraints_satisfied(self):
multiworld.worlds[1].random = multiworld.per_slot_randoms[1]
generate_disconnected_region_grid(multiworld, 5)

result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection)
result = randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)
for exit_name, entrance_name in result.pairings:
# we have labeled our entrances in such a way that all the entrances contain their group in the name
# so test for that since the ER target will have been discarded
Expand All @@ -292,7 +302,7 @@ def test_minimal_entrance_rando(self):
e = multiworld.get_entrance("region1_right", 1)
set_rule(e, lambda state: False)

randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_selection)
randomize_entrances(multiworld.worlds[1], False, directionally_matched_group_lookup)

self.assertEqual([], [entrance for region in multiworld.get_regions()
for entrance in region.entrances if not entrance.parent_region])
Expand All @@ -307,7 +317,7 @@ def test_fails_when_mismatched_entrance_and_exit_count(self):
multiworld.get_region("region1", 1).create_exit("extra")

self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
directionally_matched_group_selection)
directionally_matched_group_lookup)

def test_fails_when_some_unreachable_exit(self):
"""tests that entrance randomization fails if an exit is never reachable (non-minimal accessibility)"""
Expand All @@ -318,7 +328,7 @@ def test_fails_when_some_unreachable_exit(self):
set_rule(e, lambda state: False)

self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
directionally_matched_group_selection)
directionally_matched_group_lookup)

def test_fails_when_some_unconnectable_exit(self):
"""tests that entrance randomization fails if an exit can't be made into a valid placement (non-minimal)"""
Expand All @@ -335,7 +345,7 @@ class CustomRegion(Region):
generate_disconnected_region_grid(multiworld, 5, region_type=CustomRegion)

self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
directionally_matched_group_selection)
directionally_matched_group_lookup)

def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self):
"""
Expand All @@ -353,4 +363,4 @@ def test_minimal_er_fails_when_not_enough_locations_to_fit_progression(self):
set_rule(e, lambda state: False)

self.assertRaises(EntranceRandomizationError, randomize_entrances, multiworld.worlds[1], False,
directionally_matched_group_selection)
directionally_matched_group_lookup)

0 comments on commit c89f185

Please sign in to comment.