Skip to content

Commit

Permalink
Stardew Valley: Improve generation performance by around 11% by movin…
Browse files Browse the repository at this point in the history
…g calculating from rule evaluation to collect (ArchipelagoMW#4231)
  • Loading branch information
Jouramie committed Dec 17, 2024
1 parent 1a88fdd commit dafa079
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 56 deletions.
40 changes: 26 additions & 14 deletions worlds/stardew_valley/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from random import Random
from typing import Dict, Any, Iterable, Optional, Union, List, TextIO
from typing import Dict, Any, Iterable, Optional, List, TextIO

from BaseClasses import Region, Entrance, Location, Item, Tutorial, ItemClassification, MultiWorld, CollectionState
from Options import PerGameCommonOptions
Expand Down Expand Up @@ -94,7 +94,6 @@ class StardewValleyWorld(World):
randomized_entrances: Dict[str, str]

total_progression_items: int
excluded_from_total_progression_items: List[str] = [Event.received_walnuts]

def __init__(self, multiworld: MultiWorld, player: int):
super().__init__(multiworld, player)
Expand Down Expand Up @@ -182,7 +181,7 @@ def precollect_starting_season(self):

if self.options.season_randomization == SeasonRandomization.option_disabled:
for season in season_pool:
self.multiworld.push_precollected(self.create_starting_item(season))
self.multiworld.push_precollected(self.create_item(season))
return

if [item for item in self.multiworld.precollected_items[self.player]
Expand All @@ -192,12 +191,12 @@ def precollect_starting_season(self):
if self.options.season_randomization == SeasonRandomization.option_randomized_not_winter:
season_pool = [season for season in season_pool if season.name != "Winter"]

starting_season = self.create_starting_item(self.random.choice(season_pool))
starting_season = self.create_item(self.random.choice(season_pool))
self.multiworld.push_precollected(starting_season)

def precollect_farm_type_items(self):
if self.options.farm_type == FarmType.option_meadowlands and self.options.building_progression & BuildingProgression.option_progressive:
self.multiworld.push_precollected(self.create_starting_item("Progressive Coop"))
self.multiworld.push_precollected(self.create_item("Progressive Coop"))

def setup_player_events(self):
self.setup_action_events()
Expand Down Expand Up @@ -291,7 +290,7 @@ def setup_victory(self):
def get_all_location_names(self) -> List[str]:
return list(location.name for location in self.multiworld.get_locations(self.player))

def create_item(self, item: Union[str, ItemData], override_classification: ItemClassification = None) -> StardewItem:
def create_item(self, item: str | ItemData, override_classification: ItemClassification = None) -> StardewItem:
if isinstance(item, str):
item = item_table[item]

Expand All @@ -300,12 +299,6 @@ def create_item(self, item: Union[str, ItemData], override_classification: ItemC

return StardewItem(item.name, override_classification, item.code, self.player)

def create_starting_item(self, item: Union[str, ItemData]) -> StardewItem:
if isinstance(item, str):
item = item_table[item]

return StardewItem(item.name, item.classification, item.code, self.player)

def create_event_location(self, location_data: LocationData, rule: StardewRule = None, item: Optional[str] = None):
if rule is None:
rule = True_()
Expand Down Expand Up @@ -413,9 +406,19 @@ def collect(self, state: CollectionState, item: StardewItem) -> bool:
if not change:
return False

player_state = state.prog_items[self.player]

received_progression_count = player_state[Event.received_progression_item]
received_progression_count += 1
if self.total_progression_items:
# Total progression items is not set until all items are created, but collect will be called during the item creation when an item is precollected.
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count

walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount:
state.prog_items[self.player][Event.received_walnuts] += walnut_amount
player_state[Event.received_walnuts] += walnut_amount

return True

Expand All @@ -424,9 +427,18 @@ def remove(self, state: CollectionState, item: StardewItem) -> bool:
if not change:
return False

player_state = state.prog_items[self.player]

received_progression_count = player_state[Event.received_progression_item]
received_progression_count -= 1
if self.total_progression_items:
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count

walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount:
state.prog_items[self.player][Event.received_walnuts] -= walnut_amount
player_state[Event.received_walnuts] -= walnut_amount

return True

Expand Down
45 changes: 7 additions & 38 deletions worlds/stardew_valley/stardew_rule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from BaseClasses import CollectionState
from .base import BaseStardewRule, CombinableStardewRule
from .protocol import StardewRule
from ..strings.ap_names.event_names import Event

if TYPE_CHECKING:
from .. import StardewValleyWorld
Expand Down Expand Up @@ -87,45 +88,13 @@ def __repr__(self):
return f"Reach {self.resolution_hint} {self.spot}"


@dataclass(frozen=True)
class HasProgressionPercent(CombinableStardewRule):
player: int
percent: int
class HasProgressionPercent(Received):
def __init__(self, player: int, percent: int):
super().__init__(Event.received_progression_percent, player, percent, event=True)

def __post_init__(self):
assert self.percent > 0, "HasProgressionPercent rule must be above 0%"
assert self.percent <= 100, "HasProgressionPercent rule can't require more than 100% of items"

@property
def combination_key(self) -> Hashable:
return HasProgressionPercent.__name__

@property
def value(self):
return self.percent

def __call__(self, state: CollectionState) -> bool:
stardew_world: "StardewValleyWorld" = state.multiworld.worlds[self.player]
total_count = stardew_world.total_progression_items
needed_count = (total_count * self.percent) // 100
player_state = state.prog_items[self.player]

if needed_count <= len(player_state) - len(stardew_world.excluded_from_total_progression_items):
return True

total_count = 0
for item, item_count in player_state.items():
if item in stardew_world.excluded_from_total_progression_items:
continue

total_count += item_count
if total_count >= needed_count:
return True

return False

def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
return self, self(state)
assert self.count > 0, "HasProgressionPercent rule must be above 0%"
assert self.count <= 100, "HasProgressionPercent rule can't require more than 100% of items"

def __repr__(self):
return f"Received {self.percent}% progression items"
return f"Received {self.count}% progression items"
2 changes: 2 additions & 0 deletions worlds/stardew_valley/strings/ap_names/event_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ class Event:
winter_farming = event("Winter Farming")

received_walnuts = event("Received Walnuts")
received_progression_item = event("Received Progression Item")
received_progression_percent = event("Received Progression Percent")
11 changes: 7 additions & 4 deletions worlds/stardew_valley/test/rules/TestShipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ class TestShipsanityEverything(SVTestBase):
def test_all_shipsanity_locations_require_shipping_bin(self):
bin_name = "Shipping Bin"
self.collect_all_except(bin_name)
shipsanity_locations = [location for location in self.get_real_locations() if
LocationTags.SHIPSANITY in location_table[location.name].tags]
shipsanity_locations = [location
for location in self.get_real_locations()
if LocationTags.SHIPSANITY in location_table[location.name].tags]
bin_item = self.create_item(bin_name)

for location in shipsanity_locations:
with self.subTest(location.name):
self.remove(bin_item)
self.assertFalse(self.world.logic.region.can_reach_location(location.name)(self.multiworld.state))
self.multiworld.state.collect(bin_item)

self.collect(bin_item)
shipsanity_rule = self.world.logic.region.can_reach_location(location.name)
self.assert_rule_true(shipsanity_rule, self.multiworld.state)

self.remove(bin_item)

0 comments on commit dafa079

Please sign in to comment.